Add files using upload-large-folder tool
Browse files- .dev_scripts/build_docs.sh +8 -0
- .github/ISSUE_TEMPLATE/bug_report.md +19 -0
- .ipynb_checkpoints/README-checkpoint.md +423 -0
- COT_TRAIN.jsonl +0 -0
- docs/transformers/build/lib/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py +209 -0
- docs/transformers/build/lib/transformers/models/rwkv/modeling_rwkv.py +850 -0
- docs/transformers/build/lib/transformers/models/sam/__init__.py +30 -0
- docs/transformers/build/lib/transformers/models/sam/configuration_sam.py +337 -0
- docs/transformers/build/lib/transformers/models/sam/image_processing_sam.py +1494 -0
- docs/transformers/build/lib/transformers/models/sam/modeling_sam.py +1579 -0
- docs/transformers/build/lib/transformers/models/sam/modeling_tf_sam.py +1726 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/__init__.py +31 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/configuration_seamless_m4t.py +416 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +309 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t/tokenization_seamless_m4t.py +567 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t_v2/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +425 -0
- test.sh +6 -0
- test_qwenOmni.py +177 -0
- train.sh +64 -0
.dev_scripts/build_docs.sh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pip install -r requirements/docs.txt
|
| 2 |
+
cd docs
|
| 3 |
+
rm -rf build
|
| 4 |
+
|
| 5 |
+
# update api rst
|
| 6 |
+
#rm -rf source/api/
|
| 7 |
+
#sphinx-apidoc --module-first -o source/api/ ../modelscope/
|
| 8 |
+
make html
|
.github/ISSUE_TEMPLATE/bug_report.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Bug report
|
| 3 |
+
about: Create a report to help us improve
|
| 4 |
+
title: ''
|
| 5 |
+
labels: ''
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Describe the bug**
|
| 11 |
+
What the bug is, and how to reproduce, better with screenshots(描述bug以及复现过程,最好有截图)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
**Your hardware and system info**
|
| 15 |
+
Write your system info like CUDA version/system/GPU/torch version here(在这里给出硬件信息和系统信息,如CUDA版本,系统,GPU型号和torch版本等)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
**Additional context**
|
| 19 |
+
Add any other context about the problem here(在这里补充其他信息)
|
.ipynb_checkpoints/README-checkpoint.md
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<br>
|
| 5 |
+
<img src="asset/banner.png"/>
|
| 6 |
+
<br>
|
| 7 |
+
<p>
|
| 8 |
+
<p align="center">
|
| 9 |
+
<a href="https://modelscope.cn/home">ModelScope Community Website</a>
|
| 10 |
+
<br>
|
| 11 |
+
<a href="README_CN.md">中文</a>   |   English  
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
| 16 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
| 17 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
| 18 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
| 19 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
| 20 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
| 21 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
<p align="center">
|
| 25 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 26 |
+
</p>
|
| 27 |
+
|
| 28 |
+
<p align="center">
|
| 29 |
+
<a href="https://arxiv.org/abs/2408.05517">Paper</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
| 30 |
+
</p>
|
| 31 |
+
|
| 32 |
+
## 📖 Table of Contents
|
| 33 |
+
- [Groups](#-Groups)
|
| 34 |
+
- [Introduction](#-introduction)
|
| 35 |
+
- [News](#-news)
|
| 36 |
+
- [Installation](#%EF%B8%8F-installation)
|
| 37 |
+
- [Quick Start](#-quick-Start)
|
| 38 |
+
- [Usage](#-Usage)
|
| 39 |
+
- [License](#-License)
|
| 40 |
+
- [Citation](#-citation)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## ☎ Groups
|
| 44 |
+
|
| 45 |
+
You can contact us and communicate with us by adding our group:
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
|
| 49 |
+
:-------------------------:|:-------------------------:
|
| 50 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
## 📝 Introduction
|
| 54 |
+
🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports the training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment of 500+ large models and 200+ multi-modal large models. These large language models (LLMs) include models such as Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal LLMs include models such as Qwen2.5-VL, Qwen2-Audio, Llama3.4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2.
|
| 55 |
+
|
| 56 |
+
🍔 Additionally, ms-swift incorporates the latest training technologies, including lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment training methods like DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of inference, evaluation, and deployment modules using vLLM and LMDeploy, and it supports model quantization with technologies like GPTQ, AWQ, and BNB. Furthermore, ms-swift offers a Gradio-based Web UI and a wealth of best practices.
|
| 57 |
+
|
| 58 |
+
**Why choose ms-swift?**
|
| 59 |
+
|
| 60 |
+
- 🍎 **Model Types**: Supports 500+ pure text large models, **200+ multi-modal large models**, as well as All-to-All multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**.
|
| 61 |
+
- **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human alignment, multi-modal datasets, and supports custom datasets.
|
| 62 |
+
- **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc.
|
| 63 |
+
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
|
| 64 |
+
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
|
| 65 |
+
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
|
| 66 |
+
- **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models.
|
| 67 |
+
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
|
| 68 |
+
- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
|
| 69 |
+
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
|
| 70 |
+
- 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment.
|
| 71 |
+
- **Inference Acceleration**: Supports inference acceleration engines like PyTorch, vLLM, LmDeploy, and provides OpenAI API for accelerating inference, deployment, and evaluation modules.
|
| 72 |
+
- **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure text and multi-modal models.
|
| 73 |
+
- **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports, with models that can use vLLM/LmDeploy for inference acceleration and continue training.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## 🎉 News
|
| 77 |
+
- 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models) .
|
| 78 |
+
- 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
|
| 79 |
+
- 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh).
|
| 80 |
+
- 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html).
|
| 81 |
+
- 🎁 2025.03.15: Fine-tuning of embedding models for both pure text and multimodal models is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding).
|
| 82 |
+
- 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vllm is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh).
|
| 83 |
+
- 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WanDB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz).
|
| 84 |
+
- 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh).
|
| 85 |
+
- 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md).
|
| 86 |
+
- 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html).
|
| 87 |
+
<details><summary>More</summary>
|
| 88 |
+
|
| 89 |
+
- 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517).
|
| 90 |
+
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
|
| 91 |
+
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
|
| 92 |
+
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
|
| 93 |
+
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
|
| 94 |
+
</details>
|
| 95 |
+
|
| 96 |
+
## 🛠️ Installation
|
| 97 |
+
To install using pip:
|
| 98 |
+
```shell
|
| 99 |
+
pip install ms-swift -U
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
To install from source:
|
| 103 |
+
```shell
|
| 104 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 105 |
+
|
| 106 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 107 |
+
cd ms-swift
|
| 108 |
+
pip install -e .
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Running Environment:
|
| 112 |
+
|
| 113 |
+
| | Range | Recommended | Notes |
|
| 114 |
+
| ------------ |--------------| ----------- | ----------------------------------------- |
|
| 115 |
+
| python | >=3.9 | 3.10 | |
|
| 116 |
+
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
|
| 117 |
+
| torch | >=2.0 | | |
|
| 118 |
+
| transformers | >=4.33 | 4.51 | |
|
| 119 |
+
| modelscope | >=1.23 | | |
|
| 120 |
+
| peft | >=0.11,<0.16 | ||
|
| 121 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 122 |
+
| deepspeed | >=0.14 | 0.14.5 | Training |
|
| 123 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 | Inference/Deployment/Evaluation |
|
| 124 |
+
| lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
|
| 125 |
+
| evalscope | >=0.11 | | Evaluation |
|
| 126 |
+
|
| 127 |
+
For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh).
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
## 🚀 Quick Start
|
| 131 |
+
|
| 132 |
+
10 minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU:
|
| 133 |
+
|
| 134 |
+
### Command Line Interface
|
| 135 |
+
|
| 136 |
+
```shell
|
| 137 |
+
# 22GB
|
| 138 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 139 |
+
swift sft \
|
| 140 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 141 |
+
--train_type lora \
|
| 142 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 143 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 144 |
+
'swift/self-cognition#500' \
|
| 145 |
+
--torch_dtype bfloat16 \
|
| 146 |
+
--num_train_epochs 1 \
|
| 147 |
+
--per_device_train_batch_size 1 \
|
| 148 |
+
--per_device_eval_batch_size 1 \
|
| 149 |
+
--learning_rate 1e-4 \
|
| 150 |
+
--lora_rank 8 \
|
| 151 |
+
--lora_alpha 32 \
|
| 152 |
+
--target_modules all-linear \
|
| 153 |
+
--gradient_accumulation_steps 16 \
|
| 154 |
+
--eval_steps 50 \
|
| 155 |
+
--save_steps 50 \
|
| 156 |
+
--save_total_limit 2 \
|
| 157 |
+
--logging_steps 5 \
|
| 158 |
+
--max_length 2048 \
|
| 159 |
+
--output_dir output \
|
| 160 |
+
--system 'You are a helpful assistant.' \
|
| 161 |
+
--warmup_ratio 0.05 \
|
| 162 |
+
--dataloader_num_workers 4 \
|
| 163 |
+
--model_author swift \
|
| 164 |
+
--model_name swift-robot
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
Tips:
|
| 168 |
+
|
| 169 |
+
- If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset <dataset_path>`.
|
| 170 |
+
- The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`.
|
| 171 |
+
- To train with a different model, simply modify `--model <model_id/model_path>`.
|
| 172 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 173 |
+
|
| 174 |
+
After training is complete, use the following command to infer with the trained weights:
|
| 175 |
+
|
| 176 |
+
- Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model`, `--system` separately; Swift will automatically read these parameters. To disable this behavior, you can set `--load_args false`.
|
| 177 |
+
|
| 178 |
+
```shell
|
| 179 |
+
# Using an interactive command line for inference.
|
| 180 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 181 |
+
swift infer \
|
| 182 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 183 |
+
--stream true \
|
| 184 |
+
--temperature 0 \
|
| 185 |
+
--max_new_tokens 2048
|
| 186 |
+
|
| 187 |
+
# merge-lora and use vLLM for inference acceleration
|
| 188 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 189 |
+
swift infer \
|
| 190 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 191 |
+
--stream true \
|
| 192 |
+
--merge_lora true \
|
| 193 |
+
--infer_backend vllm \
|
| 194 |
+
--max_model_len 8192 \
|
| 195 |
+
--temperature 0 \
|
| 196 |
+
--max_new_tokens 2048
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
Finally, use the following command to push the model to ModelScope:
|
| 200 |
+
|
| 201 |
+
```shell
|
| 202 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 203 |
+
swift export \
|
| 204 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 205 |
+
--push_to_hub true \
|
| 206 |
+
--hub_model_id '<your-model-id>' \
|
| 207 |
+
--hub_token '<your-sdk-token>' \
|
| 208 |
+
--use_hf false
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
### Web-UI
|
| 213 |
+
The Web-UI is a **zero-threshold** training and deployment interface solution based on Gradio interface technology. For more details, you can check [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html).
|
| 214 |
+
|
| 215 |
+
```shell
|
| 216 |
+
SWIFT_UI_LANG=en swift web-ui
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+

|
| 220 |
+
|
| 221 |
+
### Using Python
|
| 222 |
+
|
| 223 |
+
ms-swift also supports training and inference using Python. Below is pseudocode for training and inference. For more details, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb).
|
| 224 |
+
|
| 225 |
+
Training:
|
| 226 |
+
|
| 227 |
+
```python
|
| 228 |
+
# Retrieve the model and template, and add a trainable LoRA module
|
| 229 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
| 230 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
| 231 |
+
model = Swift.prepare_model(model, lora_config)
|
| 232 |
+
|
| 233 |
+
# Download and load the dataset, and encode the text into tokens
|
| 234 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
| 235 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
| 236 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
| 237 |
+
|
| 238 |
+
# Train the model
|
| 239 |
+
trainer = Seq2SeqTrainer(
|
| 240 |
+
model=model,
|
| 241 |
+
args=training_args,
|
| 242 |
+
data_collator=template.data_collator,
|
| 243 |
+
train_dataset=train_dataset,
|
| 244 |
+
eval_dataset=val_dataset,
|
| 245 |
+
template=template,
|
| 246 |
+
)
|
| 247 |
+
trainer.train()
|
| 248 |
+
```
|
| 249 |
+
Inference:
|
| 250 |
+
|
| 251 |
+
```python
|
| 252 |
+
# Perform inference using the native PyTorch engine
|
| 253 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
| 254 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
| 255 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
| 256 |
+
|
| 257 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 258 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
## ✨ Usage
|
| 262 |
+
Here is a minimal example of training to deployment using ms-swift. For more details, you can check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples).
|
| 263 |
+
|
| 264 |
+
- If you want to use other models or datasets (including multimodal models and datasets), you only need to modify `--model` to specify the corresponding model's ID or path, and modify `--dataset` to specify the corresponding dataset's ID or path.
|
| 265 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 266 |
+
|
| 267 |
+
| Useful Links |
|
| 268 |
+
| ------ |
|
| 269 |
+
| [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) |
|
| 270 |
+
| [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) |
|
| 271 |
+
| [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) |
|
| 272 |
+
| [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
| 273 |
+
|
| 274 |
+
### Training
|
| 275 |
+
|
| 276 |
+
Supported Training Methods:
|
| 277 |
+
|
| 278 |
+
| Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal |
|
| 279 |
+
|------------------------------------|--------------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|----------------------------------------------------------------------------------------------|
|
| 280 |
+
| Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| 281 |
+
| Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
| 282 |
+
| DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
| 283 |
+
| GRPO Training | [✅]((https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh)) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
| 284 |
+
| Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
| 285 |
+
| PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
| 286 |
+
| KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
| 287 |
+
| CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
| 288 |
+
| SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
| 289 |
+
| ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
| 290 |
+
| Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
| 291 |
+
| Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
Pre-training:
|
| 296 |
+
```shell
|
| 297 |
+
# 8*A100
|
| 298 |
+
NPROC_PER_NODE=8 \
|
| 299 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 300 |
+
swift pt \
|
| 301 |
+
--model Qwen/Qwen2.5-7B \
|
| 302 |
+
--dataset swift/chinese-c4 \
|
| 303 |
+
--streaming true \
|
| 304 |
+
--train_type full \
|
| 305 |
+
--deepspeed zero2 \
|
| 306 |
+
--output_dir output \
|
| 307 |
+
--max_steps 10000 \
|
| 308 |
+
...
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
Fine-tuning:
|
| 312 |
+
```shell
|
| 313 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 314 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 315 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-en \
|
| 316 |
+
--train_type lora \
|
| 317 |
+
--output_dir output \
|
| 318 |
+
...
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
RLHF:
|
| 322 |
+
```shell
|
| 323 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
| 324 |
+
--rlhf_type dpo \
|
| 325 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 326 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
| 327 |
+
--train_type lora \
|
| 328 |
+
--output_dir output \
|
| 329 |
+
...
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
### Inference
|
| 334 |
+
```shell
|
| 335 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 336 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 337 |
+
--stream true \
|
| 338 |
+
--infer_backend pt \
|
| 339 |
+
--max_new_tokens 2048
|
| 340 |
+
|
| 341 |
+
# LoRA
|
| 342 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 343 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 344 |
+
--adapters swift/test_lora \
|
| 345 |
+
--stream true \
|
| 346 |
+
--infer_backend pt \
|
| 347 |
+
--temperature 0 \
|
| 348 |
+
--max_new_tokens 2048
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
### Interface Inference
|
| 352 |
+
```shell
|
| 353 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
| 354 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 355 |
+
--stream true \
|
| 356 |
+
--infer_backend pt \
|
| 357 |
+
--max_new_tokens 2048
|
| 358 |
+
```
|
| 359 |
+
|
| 360 |
+
### Deployment
|
| 361 |
+
```shell
|
| 362 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
| 363 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 364 |
+
--infer_backend vllm
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
### Sampling
|
| 368 |
+
```shell
|
| 369 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
| 370 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
| 371 |
+
--sampler_engine pt \
|
| 372 |
+
--num_return_sequences 5 \
|
| 373 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
### Evaluation
|
| 377 |
+
```shell
|
| 378 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
| 379 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 380 |
+
--infer_backend lmdeploy \
|
| 381 |
+
--eval_backend OpenCompass \
|
| 382 |
+
--eval_dataset ARC_c
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
### Quantization
|
| 386 |
+
```shell
|
| 387 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
| 388 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 389 |
+
--quant_bits 4 --quant_method awq \
|
| 390 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 391 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
### Push Model
|
| 395 |
+
```shell
|
| 396 |
+
swift export \
|
| 397 |
+
--model <model-path> \
|
| 398 |
+
--push_to_hub true \
|
| 399 |
+
--hub_model_id '<model-id>' \
|
| 400 |
+
--hub_token '<sdk-token>'
|
| 401 |
+
```
|
| 402 |
+
|
| 403 |
+
## 🏛 License
|
| 404 |
+
|
| 405 |
+
This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource page and follow the corresponding License.
|
| 406 |
+
|
| 407 |
+
## 📎 Citation
|
| 408 |
+
|
| 409 |
+
```bibtex
|
| 410 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
| 411 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
| 412 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
| 413 |
+
year={2024},
|
| 414 |
+
eprint={2408.05517},
|
| 415 |
+
archivePrefix={arXiv},
|
| 416 |
+
primaryClass={cs.CL},
|
| 417 |
+
url={https://arxiv.org/abs/2408.05517},
|
| 418 |
+
}
|
| 419 |
+
```
|
| 420 |
+
|
| 421 |
+
## Star History
|
| 422 |
+
|
| 423 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
COT_TRAIN.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docs/transformers/build/lib/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert a RWKV checkpoint from BlinkDL to the Hugging Face format."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import gc
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import re
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from huggingface_hub import hf_hub_download, split_torch_state_dict_into_shards
|
| 25 |
+
|
| 26 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, RwkvConfig
|
| 27 |
+
from transformers.modeling_utils import WEIGHTS_INDEX_NAME
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
NUM_HIDDEN_LAYERS_MAPPING = {
|
| 31 |
+
"169M": 12,
|
| 32 |
+
"430M": 24,
|
| 33 |
+
"1B5": 24,
|
| 34 |
+
"3B": 32,
|
| 35 |
+
"7B": 32,
|
| 36 |
+
"14B": 40,
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
HIDEN_SIZE_MAPPING = {
|
| 40 |
+
"169M": 768,
|
| 41 |
+
"430M": 1024,
|
| 42 |
+
"1B5": 2048,
|
| 43 |
+
"3B": 2560,
|
| 44 |
+
"7B": 4096,
|
| 45 |
+
"14B": 5120,
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def convert_state_dict(state_dict):
|
| 50 |
+
state_dict_keys = list(state_dict.keys())
|
| 51 |
+
for name in state_dict_keys:
|
| 52 |
+
weight = state_dict.pop(name)
|
| 53 |
+
# emb -> embedding
|
| 54 |
+
if name.startswith("emb."):
|
| 55 |
+
name = name.replace("emb.", "embeddings.")
|
| 56 |
+
# ln_0 -> pre_ln (only present at block 0)
|
| 57 |
+
if name.startswith("blocks.0.ln0"):
|
| 58 |
+
name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
|
| 59 |
+
# att -> attention
|
| 60 |
+
name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
|
| 61 |
+
# ffn -> feed_forward
|
| 62 |
+
name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
|
| 63 |
+
# time_mix_k -> time_mix_key and reshape
|
| 64 |
+
if name.endswith(".time_mix_k"):
|
| 65 |
+
name = name.replace(".time_mix_k", ".time_mix_key")
|
| 66 |
+
# time_mix_v -> time_mix_value and reshape
|
| 67 |
+
if name.endswith(".time_mix_v"):
|
| 68 |
+
name = name.replace(".time_mix_v", ".time_mix_value")
|
| 69 |
+
# time_mix_r -> time_mix_key and reshape
|
| 70 |
+
if name.endswith(".time_mix_r"):
|
| 71 |
+
name = name.replace(".time_mix_r", ".time_mix_receptance")
|
| 72 |
+
|
| 73 |
+
if name != "head.weight":
|
| 74 |
+
name = "rwkv." + name
|
| 75 |
+
|
| 76 |
+
state_dict[name] = weight
|
| 77 |
+
return state_dict
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def convert_rmkv_checkpoint_to_hf_format(
|
| 81 |
+
repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None
|
| 82 |
+
):
|
| 83 |
+
# 1. If possible, build the tokenizer.
|
| 84 |
+
if tokenizer_file is None:
|
| 85 |
+
print("No `--tokenizer_file` provided, we will use the default tokenizer.")
|
| 86 |
+
vocab_size = 50277
|
| 87 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
| 88 |
+
else:
|
| 89 |
+
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
|
| 90 |
+
vocab_size = len(tokenizer)
|
| 91 |
+
tokenizer.save_pretrained(output_dir)
|
| 92 |
+
|
| 93 |
+
# 2. Build the config
|
| 94 |
+
possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys())
|
| 95 |
+
if size is None:
|
| 96 |
+
# Try to infer size from the checkpoint name
|
| 97 |
+
for candidate in possible_sizes:
|
| 98 |
+
if candidate in checkpoint_file:
|
| 99 |
+
size = candidate
|
| 100 |
+
break
|
| 101 |
+
if size is None:
|
| 102 |
+
raise ValueError("Could not infer the size, please provide it with the `--size` argument.")
|
| 103 |
+
if size not in possible_sizes:
|
| 104 |
+
raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.")
|
| 105 |
+
|
| 106 |
+
config = RwkvConfig(
|
| 107 |
+
vocab_size=vocab_size,
|
| 108 |
+
num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size],
|
| 109 |
+
hidden_size=HIDEN_SIZE_MAPPING[size],
|
| 110 |
+
)
|
| 111 |
+
config.save_pretrained(output_dir)
|
| 112 |
+
|
| 113 |
+
# 3. Download model file then convert state_dict
|
| 114 |
+
model_file = hf_hub_download(repo_id, checkpoint_file)
|
| 115 |
+
state_dict = torch.load(model_file, map_location="cpu", weights_only=True)
|
| 116 |
+
state_dict = convert_state_dict(state_dict)
|
| 117 |
+
|
| 118 |
+
# 4. Split in shards and save
|
| 119 |
+
state_dict_split = split_torch_state_dict_into_shards(state_dict)
|
| 120 |
+
shards = index = None
|
| 121 |
+
for tensors in state_dict_split.filename_to_tensors.values():
|
| 122 |
+
shards = {tensor: state_dict[tensor] for tensor in tensors}
|
| 123 |
+
if state_dict_split.is_sharded:
|
| 124 |
+
index = {
|
| 125 |
+
"metadata": state_dict_split.metadata,
|
| 126 |
+
"weight_map": state_dict_split.tensor_to_filename,
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
for shard_file, shard in shards.items():
|
| 130 |
+
torch.save(shard, os.path.join(output_dir, shard_file))
|
| 131 |
+
|
| 132 |
+
if index is not None:
|
| 133 |
+
save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME)
|
| 134 |
+
# Save the index as well
|
| 135 |
+
with open(save_index_file, "w", encoding="utf-8") as f:
|
| 136 |
+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
| 137 |
+
f.write(content)
|
| 138 |
+
|
| 139 |
+
# 5. Clean up shards (for some reason the file PyTorch saves take the same space as the whole state_dict
|
| 140 |
+
print(
|
| 141 |
+
"Cleaning up shards. This may error with an OOM error, it this is the case don't worry you still have converted the model."
|
| 142 |
+
)
|
| 143 |
+
shard_files = list(shards.keys())
|
| 144 |
+
|
| 145 |
+
del state_dict
|
| 146 |
+
del shards
|
| 147 |
+
gc.collect()
|
| 148 |
+
|
| 149 |
+
for shard_file in shard_files:
|
| 150 |
+
state_dict = torch.load(os.path.join(output_dir, shard_file), weights_only=True)
|
| 151 |
+
torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file))
|
| 152 |
+
|
| 153 |
+
del state_dict
|
| 154 |
+
gc.collect()
|
| 155 |
+
|
| 156 |
+
if push_to_hub:
|
| 157 |
+
if model_name is None:
|
| 158 |
+
raise ValueError("Please provide a `model_name` to push the model to the Hub.")
|
| 159 |
+
model = AutoModelForCausalLM.from_pretrained(output_dir)
|
| 160 |
+
model.push_to_hub(model_name, max_shard_size="2GB")
|
| 161 |
+
tokenizer.push_to_hub(model_name)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if __name__ == "__main__":
|
| 165 |
+
parser = argparse.ArgumentParser()
|
| 166 |
+
# Required parameters
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint."
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo."
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--output_dir", default=None, type=str, required=True, help="Where to save the converted model."
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--tokenizer_file",
|
| 178 |
+
default=None,
|
| 179 |
+
type=str,
|
| 180 |
+
help="Path to the tokenizer file to use (if not provided, only the model is converted).",
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--size",
|
| 184 |
+
default=None,
|
| 185 |
+
type=str,
|
| 186 |
+
help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.",
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--push_to_hub",
|
| 190 |
+
action="store_true",
|
| 191 |
+
help="Push to the Hub the converted model.",
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--model_name",
|
| 195 |
+
default=None,
|
| 196 |
+
type=str,
|
| 197 |
+
help="Name of the pushed model on the Hub, including the username / organization.",
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
args = parser.parse_args()
|
| 201 |
+
convert_rmkv_checkpoint_to_hf_format(
|
| 202 |
+
args.repo_id,
|
| 203 |
+
args.checkpoint_file,
|
| 204 |
+
args.output_dir,
|
| 205 |
+
size=args.size,
|
| 206 |
+
tokenizer_file=args.tokenizer_file,
|
| 207 |
+
push_to_hub=args.push_to_hub,
|
| 208 |
+
model_name=args.model_name,
|
| 209 |
+
)
|
docs/transformers/build/lib/transformers/models/rwkv/modeling_rwkv.py
ADDED
|
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Bo Peng and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch RWKV model."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
|
| 27 |
+
from ...generation import GenerationMixin
|
| 28 |
+
from ...modeling_utils import PreTrainedModel
|
| 29 |
+
from ...utils import (
|
| 30 |
+
ModelOutput,
|
| 31 |
+
add_code_sample_docstrings,
|
| 32 |
+
add_start_docstrings,
|
| 33 |
+
add_start_docstrings_to_model_forward,
|
| 34 |
+
is_bitsandbytes_available,
|
| 35 |
+
is_ninja_available,
|
| 36 |
+
is_torch_cuda_available,
|
| 37 |
+
logging,
|
| 38 |
+
)
|
| 39 |
+
from .configuration_rwkv import RwkvConfig
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
_CHECKPOINT_FOR_DOC = "RWKV/rwkv-4-169m-pile"
|
| 45 |
+
_CONFIG_FOR_DOC = "RwkvConfig"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
rwkv_cuda_kernel = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_wkv_cuda_kernel(context_length):
|
| 52 |
+
from torch.utils.cpp_extension import load as load_kernel
|
| 53 |
+
|
| 54 |
+
global rwkv_cuda_kernel
|
| 55 |
+
|
| 56 |
+
kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv"
|
| 57 |
+
cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]]
|
| 58 |
+
|
| 59 |
+
# Only load the kernel if it's not been loaded yet or if we changed the context length
|
| 60 |
+
if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length:
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.")
|
| 64 |
+
|
| 65 |
+
flags = [
|
| 66 |
+
"-res-usage",
|
| 67 |
+
"--maxrregcount 60",
|
| 68 |
+
"--use_fast_math",
|
| 69 |
+
"-O3",
|
| 70 |
+
"-Xptxas -O3",
|
| 71 |
+
"--extra-device-vectorization",
|
| 72 |
+
f"-DTmax={context_length}",
|
| 73 |
+
]
|
| 74 |
+
rwkv_cuda_kernel = load_kernel(
|
| 75 |
+
name=f"wkv_{context_length}",
|
| 76 |
+
sources=cuda_kernel_files,
|
| 77 |
+
verbose=(logging.get_verbosity() == logging.DEBUG),
|
| 78 |
+
extra_cuda_cflags=flags,
|
| 79 |
+
)
|
| 80 |
+
rwkv_cuda_kernel.max_seq_length = context_length
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class RwkvLinearAttention(torch.autograd.Function):
|
| 84 |
+
@staticmethod
|
| 85 |
+
def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
|
| 86 |
+
batch_size, seq_len, hidden_size = key.size()
|
| 87 |
+
if seq_len > rwkv_cuda_kernel.max_seq_length:
|
| 88 |
+
raise ValueError(
|
| 89 |
+
f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
|
| 90 |
+
f"{rwkv_cuda_kernel.max_seq_length} with this model."
|
| 91 |
+
)
|
| 92 |
+
if batch_size * hidden_size % min(hidden_size, 32) != 0:
|
| 93 |
+
raise ValueError(
|
| 94 |
+
f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
|
| 95 |
+
f"multiple of {min(hidden_size, 32)}."
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
ctx.input_dtype = key.dtype
|
| 99 |
+
|
| 100 |
+
if (
|
| 101 |
+
time_decay.device.type != "cuda"
|
| 102 |
+
or time_first.device.type != "cuda"
|
| 103 |
+
or key.device.type != "cuda"
|
| 104 |
+
or value.device.type != "cuda"
|
| 105 |
+
):
|
| 106 |
+
raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")
|
| 107 |
+
|
| 108 |
+
time_decay = -torch.exp(time_decay.float().contiguous())
|
| 109 |
+
if key.dtype == torch.float16:
|
| 110 |
+
time_first = time_first.float()
|
| 111 |
+
key = key.float()
|
| 112 |
+
value = value.float()
|
| 113 |
+
time_first = time_first.contiguous()
|
| 114 |
+
key = key.contiguous()
|
| 115 |
+
value = value.contiguous()
|
| 116 |
+
# The CUDA kernel will fill this tensor.
|
| 117 |
+
output = torch.empty_like(key, memory_format=torch.contiguous_format)
|
| 118 |
+
if return_state or state is not None:
|
| 119 |
+
if state is None:
|
| 120 |
+
state = torch.zeros(
|
| 121 |
+
batch_size,
|
| 122 |
+
hidden_size,
|
| 123 |
+
3,
|
| 124 |
+
dtype=torch.float32,
|
| 125 |
+
device=key.device,
|
| 126 |
+
memory_format=torch.contiguous_format,
|
| 127 |
+
)
|
| 128 |
+
state[:, :, 2] -= 1e38
|
| 129 |
+
else:
|
| 130 |
+
state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
|
| 131 |
+
if key.dtype == torch.bfloat16:
|
| 132 |
+
forward_func = rwkv_cuda_kernel.forward_with_state_bf16
|
| 133 |
+
else:
|
| 134 |
+
forward_func = rwkv_cuda_kernel.forward_with_state
|
| 135 |
+
forward_func(time_decay, time_first, key, value, output, state)
|
| 136 |
+
else:
|
| 137 |
+
forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
|
| 138 |
+
forward_func(time_decay, time_first, key, value, output)
|
| 139 |
+
|
| 140 |
+
ctx.save_for_backward(time_decay, time_first, key, value, output)
|
| 141 |
+
|
| 142 |
+
if state is not None:
|
| 143 |
+
state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]
|
| 144 |
+
|
| 145 |
+
return output.to(ctx.input_dtype), state
|
| 146 |
+
|
| 147 |
+
@staticmethod
|
| 148 |
+
# g stands for grad
|
| 149 |
+
def backward(ctx, g_output, g_state=None):
|
| 150 |
+
input_dtype = ctx.input_dtype
|
| 151 |
+
|
| 152 |
+
time_decay, time_first, key, value, output = ctx.saved_tensors
|
| 153 |
+
# The CUDA kernel will fill those tensors.
|
| 154 |
+
g_time_decay = torch.empty_like(
|
| 155 |
+
time_decay,
|
| 156 |
+
memory_format=torch.contiguous_format,
|
| 157 |
+
dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
|
| 158 |
+
)
|
| 159 |
+
g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)
|
| 160 |
+
g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
|
| 161 |
+
g_value = torch.empty_like(value, memory_format=torch.contiguous_format)
|
| 162 |
+
|
| 163 |
+
if input_dtype == torch.float16:
|
| 164 |
+
g_output = g_output.float()
|
| 165 |
+
backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
|
| 166 |
+
backward_func(
|
| 167 |
+
time_decay,
|
| 168 |
+
time_first,
|
| 169 |
+
key,
|
| 170 |
+
value,
|
| 171 |
+
output,
|
| 172 |
+
g_output.contiguous(),
|
| 173 |
+
g_time_decay,
|
| 174 |
+
g_time_first,
|
| 175 |
+
g_key,
|
| 176 |
+
g_value,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
return (
|
| 180 |
+
g_time_decay.to(input_dtype),
|
| 181 |
+
g_time_first.to(input_dtype),
|
| 182 |
+
g_key.to(input_dtype),
|
| 183 |
+
g_value.to(input_dtype),
|
| 184 |
+
None,
|
| 185 |
+
None,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
|
| 190 |
+
# For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
|
| 191 |
+
# within a torch.no_grad.
|
| 192 |
+
_, seq_length, _ = key.size()
|
| 193 |
+
output = torch.zeros_like(key)
|
| 194 |
+
|
| 195 |
+
if state is None:
|
| 196 |
+
num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
|
| 197 |
+
den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
|
| 198 |
+
max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
|
| 199 |
+
else:
|
| 200 |
+
num_state, den_state, max_state = state
|
| 201 |
+
# For numerical stability
|
| 202 |
+
# real_numerator_state = num_state * torch.exp(max_state)
|
| 203 |
+
# real_denominator_state = den_state * torch.exp(max_state)
|
| 204 |
+
|
| 205 |
+
time_decay = -torch.exp(time_decay)
|
| 206 |
+
|
| 207 |
+
for current_index in range(seq_length):
|
| 208 |
+
current_key = key[:, current_index].float()
|
| 209 |
+
current_value = value[:, current_index]
|
| 210 |
+
|
| 211 |
+
# wkv computation at time t
|
| 212 |
+
max_for_output = torch.maximum(max_state, current_key + time_first)
|
| 213 |
+
e1 = torch.exp(max_state - max_for_output)
|
| 214 |
+
e2 = torch.exp(current_key + time_first - max_for_output)
|
| 215 |
+
numerator = e1 * num_state + e2 * current_value
|
| 216 |
+
denominator = e1 * den_state + e2
|
| 217 |
+
output[:, current_index] = (numerator / denominator).to(output.dtype)
|
| 218 |
+
|
| 219 |
+
# Update state for next iteration
|
| 220 |
+
max_for_state = torch.maximum(max_state + time_decay, current_key)
|
| 221 |
+
e1 = torch.exp(max_state + time_decay - max_for_state)
|
| 222 |
+
e2 = torch.exp(current_key - max_for_state)
|
| 223 |
+
num_state = e1 * num_state + e2 * current_value
|
| 224 |
+
den_state = e1 * den_state + e2
|
| 225 |
+
max_state = max_for_state
|
| 226 |
+
|
| 227 |
+
if return_state or state is not None:
|
| 228 |
+
state = [num_state, den_state, max_state]
|
| 229 |
+
|
| 230 |
+
return output, state
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
|
| 234 |
+
no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
|
| 235 |
+
# Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
|
| 236 |
+
# in this case).
|
| 237 |
+
one_token = key.size(1) == 1
|
| 238 |
+
if rwkv_cuda_kernel is None or no_cuda or one_token:
|
| 239 |
+
return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
|
| 240 |
+
else:
|
| 241 |
+
return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class RwkvSelfAttention(nn.Module):
|
| 245 |
+
def __init__(self, config, layer_id=0):
|
| 246 |
+
super().__init__()
|
| 247 |
+
self.config = config
|
| 248 |
+
kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length
|
| 249 |
+
if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
|
| 250 |
+
try:
|
| 251 |
+
load_wkv_cuda_kernel(config.context_length)
|
| 252 |
+
except Exception:
|
| 253 |
+
logger.info("Could not load the custom CUDA kernel for RWKV attention.")
|
| 254 |
+
self.layer_id = layer_id
|
| 255 |
+
hidden_size = config.hidden_size
|
| 256 |
+
attention_hidden_size = (
|
| 257 |
+
config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
|
| 258 |
+
)
|
| 259 |
+
self.attention_hidden_size = attention_hidden_size
|
| 260 |
+
|
| 261 |
+
self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
|
| 262 |
+
self.time_first = nn.Parameter(torch.empty(attention_hidden_size))
|
| 263 |
+
|
| 264 |
+
self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 265 |
+
self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 266 |
+
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 267 |
+
|
| 268 |
+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
| 269 |
+
self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
|
| 270 |
+
self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
|
| 271 |
+
self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
|
| 272 |
+
self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
|
| 273 |
+
|
| 274 |
+
# TODO: maybe jit, otherwise move inside forward
|
| 275 |
+
def extract_key_value(self, hidden, state=None):
|
| 276 |
+
# Mix hidden with the previous timestep to produce key, value, receptance
|
| 277 |
+
if hidden.size(1) == 1 and state is not None:
|
| 278 |
+
shifted = state[1][:, :, self.layer_id]
|
| 279 |
+
else:
|
| 280 |
+
shifted = self.time_shift(hidden)
|
| 281 |
+
if state is not None:
|
| 282 |
+
shifted[:, 0] = state[1][:, :, self.layer_id]
|
| 283 |
+
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 284 |
+
value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
|
| 285 |
+
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 286 |
+
|
| 287 |
+
key = self.key(key)
|
| 288 |
+
value = self.value(value)
|
| 289 |
+
receptance = torch.sigmoid(self.receptance(receptance))
|
| 290 |
+
if state is not None:
|
| 291 |
+
state[1][:, :, self.layer_id] = hidden[:, -1]
|
| 292 |
+
return receptance, key, value, state
|
| 293 |
+
|
| 294 |
+
def forward(self, hidden, state=None, use_cache=False):
|
| 295 |
+
receptance, key, value, state = self.extract_key_value(hidden, state=state)
|
| 296 |
+
layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
|
| 297 |
+
rwkv, layer_state = rwkv_linear_attention(
|
| 298 |
+
self.time_decay,
|
| 299 |
+
self.time_first,
|
| 300 |
+
key,
|
| 301 |
+
value,
|
| 302 |
+
state=layer_state,
|
| 303 |
+
return_state=use_cache,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
if layer_state is not None:
|
| 307 |
+
state[2][:, :, self.layer_id] = layer_state[0]
|
| 308 |
+
state[3][:, :, self.layer_id] = layer_state[1]
|
| 309 |
+
state[4][:, :, self.layer_id] = layer_state[2]
|
| 310 |
+
|
| 311 |
+
return self.output(receptance * rwkv), state
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class RwkvFeedForward(nn.Module):
|
| 315 |
+
def __init__(self, config, layer_id=0):
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.config = config
|
| 318 |
+
self.layer_id = layer_id
|
| 319 |
+
hidden_size = config.hidden_size
|
| 320 |
+
intermediate_size = (
|
| 321 |
+
config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
| 325 |
+
self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 326 |
+
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 327 |
+
|
| 328 |
+
self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 329 |
+
self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 330 |
+
self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 331 |
+
|
| 332 |
+
def forward(self, hidden, state=None):
|
| 333 |
+
if hidden.size(1) == 1 and state is not None:
|
| 334 |
+
shifted = state[0][:, :, self.layer_id]
|
| 335 |
+
else:
|
| 336 |
+
shifted = self.time_shift(hidden)
|
| 337 |
+
if state is not None:
|
| 338 |
+
shifted[:, 0] = state[0][:, :, self.layer_id]
|
| 339 |
+
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 340 |
+
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 341 |
+
|
| 342 |
+
key = torch.square(torch.relu(self.key(key)))
|
| 343 |
+
value = self.value(key)
|
| 344 |
+
receptance = torch.sigmoid(self.receptance(receptance))
|
| 345 |
+
|
| 346 |
+
if state is not None:
|
| 347 |
+
state[0][:, :, self.layer_id] = hidden[:, -1]
|
| 348 |
+
|
| 349 |
+
return receptance * value, state
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class RwkvBlock(nn.Module):
|
| 353 |
+
def __init__(self, config, layer_id):
|
| 354 |
+
super().__init__()
|
| 355 |
+
self.config = config
|
| 356 |
+
self.layer_id = layer_id
|
| 357 |
+
|
| 358 |
+
if layer_id == 0:
|
| 359 |
+
self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 360 |
+
|
| 361 |
+
self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 362 |
+
self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 363 |
+
|
| 364 |
+
self.attention = RwkvSelfAttention(config, layer_id)
|
| 365 |
+
self.feed_forward = RwkvFeedForward(config, layer_id)
|
| 366 |
+
|
| 367 |
+
def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
|
| 368 |
+
if self.layer_id == 0:
|
| 369 |
+
hidden = self.pre_ln(hidden)
|
| 370 |
+
|
| 371 |
+
attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
|
| 372 |
+
hidden = hidden + attention
|
| 373 |
+
|
| 374 |
+
feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
|
| 375 |
+
hidden = hidden + feed_forward
|
| 376 |
+
|
| 377 |
+
outputs = (hidden, state)
|
| 378 |
+
if output_attentions:
|
| 379 |
+
outputs += (attention,)
|
| 380 |
+
else:
|
| 381 |
+
outputs += (None,)
|
| 382 |
+
|
| 383 |
+
return outputs
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class RwkvPreTrainedModel(PreTrainedModel):
|
| 387 |
+
"""
|
| 388 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 389 |
+
models.
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
config_class = RwkvConfig
|
| 393 |
+
base_model_prefix = "rwkv"
|
| 394 |
+
_no_split_modules = ["RwkvBlock"]
|
| 395 |
+
_keep_in_fp32_modules = ["time_decay", "time_first"]
|
| 396 |
+
supports_gradient_checkpointing = True
|
| 397 |
+
_is_stateful = True
|
| 398 |
+
|
| 399 |
+
def _init_weights(self, module):
|
| 400 |
+
"""Initialize the weights."""
|
| 401 |
+
if isinstance(module, RwkvSelfAttention):
|
| 402 |
+
layer_id = module.layer_id
|
| 403 |
+
num_hidden_layers = module.config.num_hidden_layers
|
| 404 |
+
hidden_size = module.config.hidden_size
|
| 405 |
+
attention_hidden_size = module.attention_hidden_size
|
| 406 |
+
|
| 407 |
+
ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
|
| 408 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
|
| 409 |
+
|
| 410 |
+
time_weight = torch.tensor(
|
| 411 |
+
[i / hidden_size for i in range(hidden_size)],
|
| 412 |
+
dtype=module.time_mix_key.dtype,
|
| 413 |
+
device=module.time_mix_key.device,
|
| 414 |
+
)
|
| 415 |
+
time_weight = time_weight[None, None, :]
|
| 416 |
+
|
| 417 |
+
decay_speed = [
|
| 418 |
+
-5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
| 419 |
+
for h in range(attention_hidden_size)
|
| 420 |
+
]
|
| 421 |
+
decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
|
| 422 |
+
zigzag = (
|
| 423 |
+
torch.tensor(
|
| 424 |
+
[(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
|
| 425 |
+
dtype=module.time_first.dtype,
|
| 426 |
+
device=module.time_first.device,
|
| 427 |
+
)
|
| 428 |
+
* 0.5
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
with torch.no_grad():
|
| 432 |
+
module.time_decay.data = decay_speed
|
| 433 |
+
module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag)
|
| 434 |
+
|
| 435 |
+
module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
|
| 436 |
+
module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
| 437 |
+
module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
|
| 438 |
+
elif isinstance(module, RwkvFeedForward):
|
| 439 |
+
layer_id = module.layer_id
|
| 440 |
+
num_hidden_layers = module.config.num_hidden_layers
|
| 441 |
+
hidden_size = module.config.hidden_size
|
| 442 |
+
|
| 443 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
|
| 444 |
+
|
| 445 |
+
time_weight = torch.tensor(
|
| 446 |
+
[i / hidden_size for i in range(hidden_size)],
|
| 447 |
+
dtype=module.time_mix_key.dtype,
|
| 448 |
+
device=module.time_mix_key.device,
|
| 449 |
+
)
|
| 450 |
+
time_weight = time_weight[None, None, :]
|
| 451 |
+
|
| 452 |
+
with torch.no_grad():
|
| 453 |
+
module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
|
| 454 |
+
module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
@dataclass
|
| 458 |
+
class RwkvOutput(ModelOutput):
|
| 459 |
+
"""
|
| 460 |
+
Class for the RWKV model outputs.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 464 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 465 |
+
state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
|
| 466 |
+
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
| 467 |
+
avoid providing the old `input_ids`.
|
| 468 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 469 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 470 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 471 |
+
|
| 472 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 473 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 474 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 475 |
+
sequence_length)`.
|
| 476 |
+
|
| 477 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 478 |
+
heads.
|
| 479 |
+
"""
|
| 480 |
+
|
| 481 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 482 |
+
state: Optional[List[torch.FloatTensor]] = None
|
| 483 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 484 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@dataclass
|
| 488 |
+
class RwkvCausalLMOutput(ModelOutput):
|
| 489 |
+
"""
|
| 490 |
+
Base class for causal language model (or autoregressive) outputs.
|
| 491 |
+
|
| 492 |
+
Args:
|
| 493 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 494 |
+
Language modeling loss (for next-token prediction).
|
| 495 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 496 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 497 |
+
state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
|
| 498 |
+
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
| 499 |
+
avoid providing the old `input_ids`.
|
| 500 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 501 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 502 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 503 |
+
|
| 504 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 505 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 506 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 507 |
+
sequence_length)`.
|
| 508 |
+
|
| 509 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 510 |
+
heads.
|
| 511 |
+
"""
|
| 512 |
+
|
| 513 |
+
loss: Optional[torch.FloatTensor] = None
|
| 514 |
+
logits: Optional[torch.FloatTensor] = None
|
| 515 |
+
state: Optional[List[torch.FloatTensor]] = None
|
| 516 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 517 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
RWKV_START_DOCSTRING = r"""
|
| 521 |
+
|
| 522 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 523 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 524 |
+
etc.)
|
| 525 |
+
|
| 526 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 527 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 528 |
+
and behavior.
|
| 529 |
+
|
| 530 |
+
Parameters:
|
| 531 |
+
config ([`RwkvConfig`]): Model configuration class with all the parameters of the model.
|
| 532 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 533 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
RWKV_INPUTS_DOCSTRING = r"""
|
| 537 |
+
Args:
|
| 538 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
| 539 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
| 540 |
+
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
| 541 |
+
sequence tokens in the vocabulary.
|
| 542 |
+
|
| 543 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
| 544 |
+
`input_ids`.
|
| 545 |
+
|
| 546 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 547 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 548 |
+
|
| 549 |
+
[What are input IDs?](../glossary#input-ids)
|
| 550 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
|
| 551 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 552 |
+
|
| 553 |
+
- 1 for tokens that are **not masked**,
|
| 554 |
+
- 0 for tokens that are **masked**.
|
| 555 |
+
|
| 556 |
+
This is currently not used by `RwkvModel`, but will be supported in the future.
|
| 557 |
+
|
| 558 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 559 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 560 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 561 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 562 |
+
model's internal embedding lookup matrix.
|
| 563 |
+
state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
|
| 564 |
+
If passed along, the model uses the previous state in all the blocks (which will give the output for the
|
| 565 |
+
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
|
| 566 |
+
use_cache (`bool`, *optional*):
|
| 567 |
+
If set to `True`, the last state is returned and can be used to quickly generate the next logits.
|
| 568 |
+
output_attentions (`bool`, *optional*):
|
| 569 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 570 |
+
tensors for more detail.
|
| 571 |
+
output_hidden_states (`bool`, *optional*):
|
| 572 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 573 |
+
more detail.
|
| 574 |
+
return_dict (`bool`, *optional*):
|
| 575 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 576 |
+
"""
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
@add_start_docstrings(
|
| 580 |
+
"The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.",
|
| 581 |
+
RWKV_START_DOCSTRING,
|
| 582 |
+
)
|
| 583 |
+
class RwkvModel(RwkvPreTrainedModel):
|
| 584 |
+
def __init__(self, config):
|
| 585 |
+
super().__init__(config)
|
| 586 |
+
|
| 587 |
+
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 588 |
+
self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
|
| 589 |
+
self.ln_out = nn.LayerNorm(config.hidden_size)
|
| 590 |
+
|
| 591 |
+
self.layers_are_rescaled = False
|
| 592 |
+
|
| 593 |
+
self.gradient_checkpointing = False
|
| 594 |
+
|
| 595 |
+
# Initialize weights and apply final processing
|
| 596 |
+
self.post_init()
|
| 597 |
+
|
| 598 |
+
def get_input_embeddings(self):
|
| 599 |
+
return self.embeddings
|
| 600 |
+
|
| 601 |
+
def set_input_embeddings(self, new_embeddings):
|
| 602 |
+
self.embeddings = new_embeddings
|
| 603 |
+
|
| 604 |
+
@add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
|
| 605 |
+
@add_code_sample_docstrings(
|
| 606 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 607 |
+
output_type=RwkvOutput,
|
| 608 |
+
config_class=_CONFIG_FOR_DOC,
|
| 609 |
+
)
|
| 610 |
+
def forward(
|
| 611 |
+
self,
|
| 612 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 613 |
+
attention_mask: Optional[torch.LongTensor] = None, # noqa
|
| 614 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 615 |
+
state: Optional[List[torch.FloatTensor]] = None,
|
| 616 |
+
use_cache: Optional[bool] = None,
|
| 617 |
+
output_attentions: Optional[bool] = None,
|
| 618 |
+
output_hidden_states: Optional[bool] = None,
|
| 619 |
+
return_dict: Optional[bool] = None,
|
| 620 |
+
) -> Union[Tuple, RwkvOutput]:
|
| 621 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 622 |
+
output_hidden_states = (
|
| 623 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 624 |
+
)
|
| 625 |
+
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
| 626 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 627 |
+
|
| 628 |
+
if attention_mask is not None:
|
| 629 |
+
logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
|
| 630 |
+
|
| 631 |
+
if self.training == self.layers_are_rescaled:
|
| 632 |
+
self._rescale_layers()
|
| 633 |
+
|
| 634 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 635 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 636 |
+
elif input_ids is None and inputs_embeds is None:
|
| 637 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 638 |
+
|
| 639 |
+
if inputs_embeds is None:
|
| 640 |
+
inputs_embeds = self.embeddings(input_ids)
|
| 641 |
+
|
| 642 |
+
if use_cache and state is None:
|
| 643 |
+
shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
|
| 644 |
+
state = [
|
| 645 |
+
torch.zeros(
|
| 646 |
+
*shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device
|
| 647 |
+
)
|
| 648 |
+
for i in range(5)
|
| 649 |
+
]
|
| 650 |
+
state[4] -= 1e30
|
| 651 |
+
|
| 652 |
+
if self.gradient_checkpointing and self.training:
|
| 653 |
+
if use_cache:
|
| 654 |
+
logger.warning_once(
|
| 655 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 656 |
+
)
|
| 657 |
+
use_cache = False
|
| 658 |
+
|
| 659 |
+
hidden_states = inputs_embeds
|
| 660 |
+
|
| 661 |
+
all_self_attentions = () if output_attentions else None
|
| 662 |
+
all_hidden_states = () if output_hidden_states else None
|
| 663 |
+
for idx, block in enumerate(self.blocks):
|
| 664 |
+
if self.gradient_checkpointing and self.training:
|
| 665 |
+
hidden_states, state, attentions = self._gradient_checkpointing_func(
|
| 666 |
+
block.__call__, hidden_states, state, use_cache, output_attentions
|
| 667 |
+
)
|
| 668 |
+
else:
|
| 669 |
+
hidden_states, state, attentions = block(
|
| 670 |
+
hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
if (
|
| 674 |
+
self.layers_are_rescaled
|
| 675 |
+
and self.config.rescale_every > 0
|
| 676 |
+
and (idx + 1) % self.config.rescale_every == 0
|
| 677 |
+
):
|
| 678 |
+
hidden_states = hidden_states / 2
|
| 679 |
+
|
| 680 |
+
if output_hidden_states:
|
| 681 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 682 |
+
|
| 683 |
+
if output_attentions:
|
| 684 |
+
all_self_attentions = all_self_attentions + (attentions,)
|
| 685 |
+
|
| 686 |
+
hidden_states = self.ln_out(hidden_states)
|
| 687 |
+
|
| 688 |
+
if output_hidden_states:
|
| 689 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 690 |
+
|
| 691 |
+
if not return_dict:
|
| 692 |
+
return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None)
|
| 693 |
+
|
| 694 |
+
return RwkvOutput(
|
| 695 |
+
last_hidden_state=hidden_states,
|
| 696 |
+
state=state,
|
| 697 |
+
hidden_states=all_hidden_states,
|
| 698 |
+
attentions=all_self_attentions,
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
def _rescale_layers(self):
|
| 702 |
+
# Layers should be rescaled for inference only.
|
| 703 |
+
if self.layers_are_rescaled == (not self.training):
|
| 704 |
+
return
|
| 705 |
+
if self.config.rescale_every > 0:
|
| 706 |
+
with torch.no_grad():
|
| 707 |
+
for block_id, block in enumerate(self.blocks):
|
| 708 |
+
if self.training:
|
| 709 |
+
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
| 710 |
+
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
| 711 |
+
else:
|
| 712 |
+
# Deal with quantization statistics
|
| 713 |
+
if hasattr(block.attention.output.weight, "SCB"):
|
| 714 |
+
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
| 715 |
+
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
| 716 |
+
elif hasattr(block.attention.output.weight, "quant_state"):
|
| 717 |
+
self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
|
| 718 |
+
self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
|
| 719 |
+
else:
|
| 720 |
+
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
| 721 |
+
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
| 722 |
+
|
| 723 |
+
self.layers_are_rescaled = not self.training
|
| 724 |
+
|
| 725 |
+
def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
|
| 726 |
+
r"""
|
| 727 |
+
Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
|
| 728 |
+
be quantized again.
|
| 729 |
+
"""
|
| 730 |
+
if not is_bitsandbytes_available():
|
| 731 |
+
raise ImportError("Please install bitsandbytes to use this method.")
|
| 732 |
+
import bitsandbytes as bnb
|
| 733 |
+
|
| 734 |
+
dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
|
| 735 |
+
|
| 736 |
+
dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
|
| 737 |
+
|
| 738 |
+
# re-quantize the model:
|
| 739 |
+
# we need to put it first on CPU then back to the device
|
| 740 |
+
# this will create an overhead :/
|
| 741 |
+
# We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
|
| 742 |
+
# bugs with bnb
|
| 743 |
+
quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
|
| 744 |
+
setattr(target_layer, "weight", quant_weight)
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
@add_start_docstrings(
|
| 748 |
+
"""
|
| 749 |
+
The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
| 750 |
+
embeddings).
|
| 751 |
+
""",
|
| 752 |
+
RWKV_START_DOCSTRING,
|
| 753 |
+
)
|
| 754 |
+
class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin):
|
| 755 |
+
_tied_weights_keys = ["head.weight"]
|
| 756 |
+
|
| 757 |
+
def __init__(self, config):
|
| 758 |
+
super().__init__(config)
|
| 759 |
+
self.rwkv = RwkvModel(config)
|
| 760 |
+
self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 761 |
+
|
| 762 |
+
# Initialize weights and apply final processing
|
| 763 |
+
self.post_init()
|
| 764 |
+
|
| 765 |
+
def get_output_embeddings(self):
|
| 766 |
+
return self.head
|
| 767 |
+
|
| 768 |
+
def set_output_embeddings(self, new_embeddings):
|
| 769 |
+
self.head = new_embeddings
|
| 770 |
+
|
| 771 |
+
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):
|
| 772 |
+
# Overwritten -- this model uses `state`, but doesn't have a cache (`past_key_values`)
|
| 773 |
+
|
| 774 |
+
# only last token for inputs_ids if the state is passed along.
|
| 775 |
+
if state is not None:
|
| 776 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 777 |
+
|
| 778 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 779 |
+
if inputs_embeds is not None and state is None:
|
| 780 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 781 |
+
else:
|
| 782 |
+
model_inputs = {"input_ids": input_ids}
|
| 783 |
+
|
| 784 |
+
model_inputs["state"] = state
|
| 785 |
+
model_inputs["use_cache"] = use_cache
|
| 786 |
+
return model_inputs
|
| 787 |
+
|
| 788 |
+
@add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
|
| 789 |
+
@add_code_sample_docstrings(
|
| 790 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 791 |
+
output_type=RwkvCausalLMOutput,
|
| 792 |
+
config_class=_CONFIG_FOR_DOC,
|
| 793 |
+
)
|
| 794 |
+
def forward(
|
| 795 |
+
self,
|
| 796 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 797 |
+
attention_mask: Optional[torch.LongTensor] = None, # noqa
|
| 798 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 799 |
+
state: Optional[List[torch.FloatTensor]] = None,
|
| 800 |
+
labels: Optional[torch.LongTensor] = None,
|
| 801 |
+
use_cache: Optional[bool] = None,
|
| 802 |
+
output_attentions: Optional[bool] = None,
|
| 803 |
+
output_hidden_states: Optional[bool] = None,
|
| 804 |
+
return_dict: Optional[bool] = None,
|
| 805 |
+
**kwargs,
|
| 806 |
+
) -> Union[Tuple, RwkvCausalLMOutput]:
|
| 807 |
+
r"""
|
| 808 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 809 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 810 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 811 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 812 |
+
"""
|
| 813 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 814 |
+
|
| 815 |
+
rwkv_outputs = self.rwkv(
|
| 816 |
+
input_ids,
|
| 817 |
+
inputs_embeds=inputs_embeds,
|
| 818 |
+
state=state,
|
| 819 |
+
use_cache=use_cache,
|
| 820 |
+
output_attentions=output_attentions,
|
| 821 |
+
output_hidden_states=output_hidden_states,
|
| 822 |
+
return_dict=return_dict,
|
| 823 |
+
)
|
| 824 |
+
hidden_states = rwkv_outputs[0]
|
| 825 |
+
|
| 826 |
+
logits = self.head(hidden_states)
|
| 827 |
+
|
| 828 |
+
loss = None
|
| 829 |
+
if labels is not None:
|
| 830 |
+
loss = self.loss_function(
|
| 831 |
+
logits,
|
| 832 |
+
labels,
|
| 833 |
+
vocab_size=self.config.vocab_size,
|
| 834 |
+
**kwargs,
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
if not return_dict:
|
| 838 |
+
output = (logits,) + rwkv_outputs[1:]
|
| 839 |
+
return ((loss,) + output) if loss is not None else output
|
| 840 |
+
|
| 841 |
+
return RwkvCausalLMOutput(
|
| 842 |
+
loss=loss,
|
| 843 |
+
logits=logits,
|
| 844 |
+
state=rwkv_outputs.state,
|
| 845 |
+
hidden_states=rwkv_outputs.hidden_states,
|
| 846 |
+
attentions=rwkv_outputs.attentions,
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
__all__ = ["RwkvForCausalLM", "RwkvModel", "RwkvPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/sam/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_sam import *
|
| 22 |
+
from .image_processing_sam import *
|
| 23 |
+
from .modeling_sam import *
|
| 24 |
+
from .modeling_tf_sam import *
|
| 25 |
+
from .processing_sam import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/sam/configuration_sam.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SAM model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SamPromptEncoderConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`SamPromptEncoder`]. The [`SamPromptEncoder`]
|
| 27 |
+
module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield
|
| 28 |
+
a similar configuration to that of the SAM-vit-h
|
| 29 |
+
[facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
hidden_size (`int`, *optional*, defaults to 256):
|
| 36 |
+
Dimensionality of the hidden states.
|
| 37 |
+
image_size (`int`, *optional*, defaults to 1024):
|
| 38 |
+
The expected output resolution of the image.
|
| 39 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 40 |
+
The size (resolution) of each patch.
|
| 41 |
+
mask_input_channels (`int`, *optional*, defaults to 16):
|
| 42 |
+
The number of channels to be fed to the `MaskDecoder` module.
|
| 43 |
+
num_point_embeddings (`int`, *optional*, defaults to 4):
|
| 44 |
+
The number of point embeddings to be used.
|
| 45 |
+
hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
| 46 |
+
The non-linear activation function in the encoder and pooler.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
base_config_key = "prompt_encoder_config"
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
hidden_size=256,
|
| 54 |
+
image_size=1024,
|
| 55 |
+
patch_size=16,
|
| 56 |
+
mask_input_channels=16,
|
| 57 |
+
num_point_embeddings=4,
|
| 58 |
+
hidden_act="gelu",
|
| 59 |
+
layer_norm_eps=1e-6,
|
| 60 |
+
**kwargs,
|
| 61 |
+
):
|
| 62 |
+
super().__init__(**kwargs)
|
| 63 |
+
self.hidden_size = hidden_size
|
| 64 |
+
self.image_size = image_size
|
| 65 |
+
self.patch_size = patch_size
|
| 66 |
+
self.image_embedding_size = image_size // patch_size
|
| 67 |
+
self.mask_input_channels = mask_input_channels
|
| 68 |
+
self.num_point_embeddings = num_point_embeddings
|
| 69 |
+
self.hidden_act = hidden_act
|
| 70 |
+
self.layer_norm_eps = layer_norm_eps
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class SamMaskDecoderConfig(PretrainedConfig):
|
| 74 |
+
r"""
|
| 75 |
+
This is the configuration class to store the configuration of a [`SamMaskDecoder`]. It is used to instantiate a SAM
|
| 76 |
+
mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults
|
| 77 |
+
will yield a similar configuration to that of the SAM-vit-h
|
| 78 |
+
[facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
|
| 79 |
+
|
| 80 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 81 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
hidden_size (`int`, *optional*, defaults to 256):
|
| 85 |
+
Dimensionality of the hidden states.
|
| 86 |
+
hidden_act (`str`, *optional*, defaults to `"relu"`):
|
| 87 |
+
The non-linear activation function used inside the `SamMaskDecoder` module.
|
| 88 |
+
mlp_dim (`int`, *optional*, defaults to 2048):
|
| 89 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 90 |
+
num_hidden_layers (`int`, *optional*, defaults to 2):
|
| 91 |
+
Number of hidden layers in the Transformer encoder.
|
| 92 |
+
num_attention_heads (`int`, *optional*, defaults to 8):
|
| 93 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 94 |
+
attention_downsample_rate (`int`, *optional*, defaults to 2):
|
| 95 |
+
The downsampling rate of the attention layer.
|
| 96 |
+
num_multimask_outputs (`int`, *optional*, defaults to 3):
|
| 97 |
+
The number of outputs from the `SamMaskDecoder` module. In the Segment Anything paper, this is set to 3.
|
| 98 |
+
iou_head_depth (`int`, *optional*, defaults to 3):
|
| 99 |
+
The number of layers in the IoU head module.
|
| 100 |
+
iou_head_hidden_dim (`int`, *optional*, defaults to 256):
|
| 101 |
+
The dimensionality of the hidden states in the IoU head module.
|
| 102 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 103 |
+
The epsilon used by the layer normalization layers.
|
| 104 |
+
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
base_config_key = "mask_decoder_config"
|
| 108 |
+
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
hidden_size=256,
|
| 112 |
+
hidden_act="relu",
|
| 113 |
+
mlp_dim=2048,
|
| 114 |
+
num_hidden_layers=2,
|
| 115 |
+
num_attention_heads=8,
|
| 116 |
+
attention_downsample_rate=2,
|
| 117 |
+
num_multimask_outputs=3,
|
| 118 |
+
iou_head_depth=3,
|
| 119 |
+
iou_head_hidden_dim=256,
|
| 120 |
+
layer_norm_eps=1e-6,
|
| 121 |
+
**kwargs,
|
| 122 |
+
):
|
| 123 |
+
super().__init__(**kwargs)
|
| 124 |
+
self.hidden_size = hidden_size
|
| 125 |
+
self.hidden_act = hidden_act
|
| 126 |
+
self.mlp_dim = mlp_dim
|
| 127 |
+
self.num_hidden_layers = num_hidden_layers
|
| 128 |
+
self.num_attention_heads = num_attention_heads
|
| 129 |
+
self.attention_downsample_rate = attention_downsample_rate
|
| 130 |
+
self.num_multimask_outputs = num_multimask_outputs
|
| 131 |
+
self.iou_head_depth = iou_head_depth
|
| 132 |
+
self.iou_head_hidden_dim = iou_head_hidden_dim
|
| 133 |
+
self.layer_norm_eps = layer_norm_eps
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class SamVisionConfig(PretrainedConfig):
|
| 137 |
+
r"""
|
| 138 |
+
This is the configuration class to store the configuration of a [`SamVisionModel`]. It is used to instantiate a SAM
|
| 139 |
+
vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 140 |
+
defaults will yield a similar configuration to that of the SAM ViT-h
|
| 141 |
+
[facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
|
| 142 |
+
|
| 143 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 144 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 148 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 149 |
+
output_channels (`int`, *optional*, defaults to 256):
|
| 150 |
+
Dimensionality of the output channels in the Patch Encoder.
|
| 151 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 152 |
+
Number of hidden layers in the Transformer encoder.
|
| 153 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 154 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 155 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 156 |
+
Number of channels in the input image.
|
| 157 |
+
image_size (`int`, *optional*, defaults to 1024):
|
| 158 |
+
Expected resolution. Target size of the resized input image.
|
| 159 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 160 |
+
Size of the patches to be extracted from the input image.
|
| 161 |
+
hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
| 162 |
+
The non-linear activation function (function or string)
|
| 163 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 164 |
+
The epsilon used by the layer normalization layers.
|
| 165 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 166 |
+
The dropout ratio for the attention probabilities.
|
| 167 |
+
initializer_range (`float`, *optional*, defaults to 1e-10):
|
| 168 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 169 |
+
qkv_bias (`bool`, *optional*, defaults to `True`):
|
| 170 |
+
Whether to add a bias to query, key, value projections.
|
| 171 |
+
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
| 172 |
+
Ratio of mlp hidden dim to embedding dim.
|
| 173 |
+
use_abs_pos (`bool`, *optional*, defaults to `True`):
|
| 174 |
+
Whether to use absolute position embedding.
|
| 175 |
+
use_rel_pos (`bool`, *optional*, defaults to `True`):
|
| 176 |
+
Whether to use relative position embedding.
|
| 177 |
+
window_size (`int`, *optional*, defaults to 14):
|
| 178 |
+
Window size for relative position.
|
| 179 |
+
global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
|
| 180 |
+
The indexes of the global attention layers.
|
| 181 |
+
num_pos_feats (`int`, *optional*, defaults to 128):
|
| 182 |
+
The dimensionality of the position embedding.
|
| 183 |
+
mlp_dim (`int`, *optional*):
|
| 184 |
+
The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
|
| 185 |
+
hidden_size`.
|
| 186 |
+
|
| 187 |
+
Example:
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
>>> from transformers import (
|
| 191 |
+
... SamVisionConfig,
|
| 192 |
+
... SamVisionModel,
|
| 193 |
+
... )
|
| 194 |
+
|
| 195 |
+
>>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
|
| 196 |
+
>>> configuration = SamVisionConfig()
|
| 197 |
+
|
| 198 |
+
>>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
|
| 199 |
+
>>> model = SamVisionModel(configuration)
|
| 200 |
+
|
| 201 |
+
>>> # Accessing the model configuration
|
| 202 |
+
>>> configuration = model.config
|
| 203 |
+
```"""
|
| 204 |
+
|
| 205 |
+
base_config_key = "vision_config"
|
| 206 |
+
model_type = "sam_vision_model"
|
| 207 |
+
|
| 208 |
+
def __init__(
|
| 209 |
+
self,
|
| 210 |
+
hidden_size=768,
|
| 211 |
+
output_channels=256,
|
| 212 |
+
num_hidden_layers=12,
|
| 213 |
+
num_attention_heads=12,
|
| 214 |
+
num_channels=3,
|
| 215 |
+
image_size=1024,
|
| 216 |
+
patch_size=16,
|
| 217 |
+
hidden_act="gelu",
|
| 218 |
+
layer_norm_eps=1e-06,
|
| 219 |
+
attention_dropout=0.0,
|
| 220 |
+
initializer_range=1e-10,
|
| 221 |
+
qkv_bias=True,
|
| 222 |
+
mlp_ratio=4.0,
|
| 223 |
+
use_abs_pos=True,
|
| 224 |
+
use_rel_pos=True,
|
| 225 |
+
window_size=14,
|
| 226 |
+
global_attn_indexes=[2, 5, 8, 11],
|
| 227 |
+
num_pos_feats=128,
|
| 228 |
+
mlp_dim=None,
|
| 229 |
+
**kwargs,
|
| 230 |
+
):
|
| 231 |
+
super().__init__(**kwargs)
|
| 232 |
+
|
| 233 |
+
self.hidden_size = hidden_size
|
| 234 |
+
self.output_channels = output_channels
|
| 235 |
+
self.num_hidden_layers = num_hidden_layers
|
| 236 |
+
self.num_attention_heads = num_attention_heads
|
| 237 |
+
self.num_channels = num_channels
|
| 238 |
+
self.image_size = image_size
|
| 239 |
+
self.patch_size = patch_size
|
| 240 |
+
self.hidden_act = hidden_act
|
| 241 |
+
self.layer_norm_eps = layer_norm_eps
|
| 242 |
+
self.attention_dropout = attention_dropout
|
| 243 |
+
self.initializer_range = initializer_range
|
| 244 |
+
self.qkv_bias = qkv_bias
|
| 245 |
+
self.mlp_ratio = mlp_ratio
|
| 246 |
+
self.use_abs_pos = use_abs_pos
|
| 247 |
+
self.use_rel_pos = use_rel_pos
|
| 248 |
+
self.window_size = window_size
|
| 249 |
+
self.global_attn_indexes = global_attn_indexes
|
| 250 |
+
self.num_pos_feats = num_pos_feats
|
| 251 |
+
self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class SamConfig(PretrainedConfig):
|
| 255 |
+
r"""
|
| 256 |
+
[`SamConfig`] is the configuration class to store the configuration of a [`SamModel`]. It is used to instantiate a
|
| 257 |
+
SAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder
|
| 258 |
+
configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
| 259 |
+
SAM-ViT-H [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
|
| 260 |
+
|
| 261 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 262 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
vision_config (Union[`dict`, `SamVisionConfig`], *optional*):
|
| 266 |
+
Dictionary of configuration options used to initialize [`SamVisionConfig`].
|
| 267 |
+
prompt_encoder_config (Union[`dict`, `SamPromptEncoderConfig`], *optional*):
|
| 268 |
+
Dictionary of configuration options used to initialize [`SamPromptEncoderConfig`].
|
| 269 |
+
mask_decoder_config (Union[`dict`, `SamMaskDecoderConfig`], *optional*):
|
| 270 |
+
Dictionary of configuration options used to initialize [`SamMaskDecoderConfig`].
|
| 271 |
+
|
| 272 |
+
kwargs (*optional*):
|
| 273 |
+
Dictionary of keyword arguments.
|
| 274 |
+
|
| 275 |
+
Example:
|
| 276 |
+
|
| 277 |
+
```python
|
| 278 |
+
>>> from transformers import (
|
| 279 |
+
... SamVisionConfig,
|
| 280 |
+
... SamPromptEncoderConfig,
|
| 281 |
+
... SamMaskDecoderConfig,
|
| 282 |
+
... SamModel,
|
| 283 |
+
... )
|
| 284 |
+
|
| 285 |
+
>>> # Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
|
| 286 |
+
>>> configuration = SamConfig()
|
| 287 |
+
|
| 288 |
+
>>> # Initializing a SamModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
|
| 289 |
+
>>> model = SamModel(configuration)
|
| 290 |
+
|
| 291 |
+
>>> # Accessing the model configuration
|
| 292 |
+
>>> configuration = model.config
|
| 293 |
+
|
| 294 |
+
>>> # We can also initialize a SamConfig from a SamVisionConfig, SamPromptEncoderConfig, and SamMaskDecoderConfig
|
| 295 |
+
|
| 296 |
+
>>> # Initializing SAM vision, SAM Q-Former and language model configurations
|
| 297 |
+
>>> vision_config = SamVisionConfig()
|
| 298 |
+
>>> prompt_encoder_config = SamPromptEncoderConfig()
|
| 299 |
+
>>> mask_decoder_config = SamMaskDecoderConfig()
|
| 300 |
+
|
| 301 |
+
>>> config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
|
| 302 |
+
```"""
|
| 303 |
+
|
| 304 |
+
model_type = "sam"
|
| 305 |
+
sub_configs = {
|
| 306 |
+
"prompt_encoder_config": SamPromptEncoderConfig,
|
| 307 |
+
"mask_decoder_config": SamMaskDecoderConfig,
|
| 308 |
+
"vision_config": SamVisionConfig,
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
def __init__(
|
| 312 |
+
self,
|
| 313 |
+
vision_config=None,
|
| 314 |
+
prompt_encoder_config=None,
|
| 315 |
+
mask_decoder_config=None,
|
| 316 |
+
initializer_range=0.02,
|
| 317 |
+
**kwargs,
|
| 318 |
+
):
|
| 319 |
+
super().__init__(**kwargs)
|
| 320 |
+
vision_config = vision_config if vision_config is not None else {}
|
| 321 |
+
prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
|
| 322 |
+
mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
|
| 323 |
+
|
| 324 |
+
if isinstance(vision_config, SamVisionConfig):
|
| 325 |
+
vision_config = vision_config.to_dict()
|
| 326 |
+
if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
|
| 327 |
+
prompt_encoder_config = prompt_encoder_config.to_dict()
|
| 328 |
+
if isinstance(mask_decoder_config, SamMaskDecoderConfig):
|
| 329 |
+
mask_decoder_config = mask_decoder_config.to_dict()
|
| 330 |
+
|
| 331 |
+
self.vision_config = SamVisionConfig(**vision_config)
|
| 332 |
+
self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
|
| 333 |
+
self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
|
| 334 |
+
self.initializer_range = initializer_range
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
__all__ = ["SamConfig", "SamMaskDecoderConfig", "SamPromptEncoderConfig", "SamVisionConfig"]
|
docs/transformers/build/lib/transformers/models/sam/image_processing_sam.py
ADDED
|
@@ -0,0 +1,1494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for SAM."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
from itertools import product
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 25 |
+
from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
|
| 26 |
+
from ...image_utils import (
|
| 27 |
+
IMAGENET_DEFAULT_MEAN,
|
| 28 |
+
IMAGENET_DEFAULT_STD,
|
| 29 |
+
ChannelDimension,
|
| 30 |
+
ImageInput,
|
| 31 |
+
PILImageResampling,
|
| 32 |
+
get_image_size,
|
| 33 |
+
infer_channel_dimension_format,
|
| 34 |
+
is_scaled_image,
|
| 35 |
+
make_list_of_images,
|
| 36 |
+
to_numpy_array,
|
| 37 |
+
valid_images,
|
| 38 |
+
validate_preprocess_arguments,
|
| 39 |
+
)
|
| 40 |
+
from ...utils import (
|
| 41 |
+
TensorType,
|
| 42 |
+
filter_out_non_signature_kwargs,
|
| 43 |
+
is_tf_available,
|
| 44 |
+
is_torch_available,
|
| 45 |
+
is_torchvision_available,
|
| 46 |
+
logging,
|
| 47 |
+
requires_backends,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if is_torch_available():
|
| 52 |
+
import torch
|
| 53 |
+
import torch.nn.functional as F
|
| 54 |
+
|
| 55 |
+
if is_torchvision_available():
|
| 56 |
+
from torchvision.ops.boxes import batched_nms
|
| 57 |
+
|
| 58 |
+
if is_tf_available():
|
| 59 |
+
import tensorflow as tf
|
| 60 |
+
from tensorflow.experimental import numpy as tnp
|
| 61 |
+
|
| 62 |
+
from ...tf_utils import flatten, shape_list
|
| 63 |
+
|
| 64 |
+
logger = logging.get_logger(__name__)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class SamImageProcessor(BaseImageProcessor):
|
| 68 |
+
r"""
|
| 69 |
+
Constructs a SAM image processor.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 73 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
| 74 |
+
`do_resize` parameter in the `preprocess` method.
|
| 75 |
+
size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
|
| 76 |
+
Size of the output image after resizing. Resizes the longest edge of the image to match
|
| 77 |
+
`size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `size` parameter in the
|
| 78 |
+
`preprocess` method.
|
| 79 |
+
mask_size (`dict`, *optional*, defaults to `{"longest_edge": 256}`):
|
| 80 |
+
Size of the output segmentation map after resizing. Resizes the longest edge of the image to match
|
| 81 |
+
`size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `mask_size` parameter
|
| 82 |
+
in the `preprocess` method.
|
| 83 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 84 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 85 |
+
`preprocess` method.
|
| 86 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 87 |
+
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
| 88 |
+
`do_rescale` parameter in the `preprocess` method.
|
| 89 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 90 |
+
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
|
| 91 |
+
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
| 92 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 93 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 94 |
+
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
|
| 95 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
|
| 96 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 97 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
|
| 98 |
+
overridden by the `image_mean` parameter in the `preprocess` method.
|
| 99 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
|
| 100 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 101 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 102 |
+
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 103 |
+
do_pad (`bool`, *optional*, defaults to `True`):
|
| 104 |
+
Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
|
| 105 |
+
`preprocess` method.
|
| 106 |
+
pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
|
| 107 |
+
Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
|
| 108 |
+
method.
|
| 109 |
+
mask_pad_size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`):
|
| 110 |
+
Size of the output segmentation map after padding. Can be overridden by the `mask_pad_size` parameter in
|
| 111 |
+
the `preprocess` method.
|
| 112 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 113 |
+
Whether to convert the image to RGB.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
model_input_names = ["pixel_values"]
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
do_resize: bool = True,
|
| 121 |
+
size: Dict[str, int] = None,
|
| 122 |
+
mask_size: Dict[str, int] = None,
|
| 123 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 124 |
+
do_rescale: bool = True,
|
| 125 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 126 |
+
do_normalize: bool = True,
|
| 127 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 128 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 129 |
+
do_pad: bool = True,
|
| 130 |
+
pad_size: Optional[int] = None,
|
| 131 |
+
mask_pad_size: Optional[int] = None,
|
| 132 |
+
do_convert_rgb: bool = True,
|
| 133 |
+
**kwargs,
|
| 134 |
+
) -> None:
|
| 135 |
+
super().__init__(**kwargs)
|
| 136 |
+
size = size if size is not None else {"longest_edge": 1024}
|
| 137 |
+
size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
|
| 138 |
+
|
| 139 |
+
pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
|
| 140 |
+
pad_size = get_size_dict(pad_size, default_to_square=True)
|
| 141 |
+
|
| 142 |
+
mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
|
| 143 |
+
mask_size = (
|
| 144 |
+
get_size_dict(max_size=mask_size, default_to_square=False)
|
| 145 |
+
if not isinstance(mask_size, dict)
|
| 146 |
+
else mask_size
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
|
| 150 |
+
mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
|
| 151 |
+
|
| 152 |
+
self.do_resize = do_resize
|
| 153 |
+
self.size = size
|
| 154 |
+
self.mask_size = mask_size
|
| 155 |
+
self.resample = resample
|
| 156 |
+
self.do_rescale = do_rescale
|
| 157 |
+
self.rescale_factor = rescale_factor
|
| 158 |
+
self.do_normalize = do_normalize
|
| 159 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 160 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 161 |
+
self.do_pad = do_pad
|
| 162 |
+
self.pad_size = pad_size
|
| 163 |
+
self.mask_pad_size = mask_pad_size
|
| 164 |
+
self.do_convert_rgb = do_convert_rgb
|
| 165 |
+
|
| 166 |
+
def pad_image(
|
| 167 |
+
self,
|
| 168 |
+
image: np.ndarray,
|
| 169 |
+
pad_size: Dict[str, int],
|
| 170 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 171 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 172 |
+
**kwargs,
|
| 173 |
+
) -> np.ndarray:
|
| 174 |
+
"""
|
| 175 |
+
Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
image (`np.ndarray`):
|
| 179 |
+
Image to pad.
|
| 180 |
+
pad_size (`Dict[str, int]`):
|
| 181 |
+
Size of the output image after padding.
|
| 182 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 183 |
+
The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
|
| 184 |
+
`data_format` of the `image` will be used.
|
| 185 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 186 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 187 |
+
"""
|
| 188 |
+
output_height, output_width = pad_size["height"], pad_size["width"]
|
| 189 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 190 |
+
|
| 191 |
+
pad_width = output_width - input_width
|
| 192 |
+
pad_height = output_height - input_height
|
| 193 |
+
|
| 194 |
+
padded_image = pad(
|
| 195 |
+
image,
|
| 196 |
+
((0, pad_height), (0, pad_width)),
|
| 197 |
+
data_format=data_format,
|
| 198 |
+
input_data_format=input_data_format,
|
| 199 |
+
**kwargs,
|
| 200 |
+
)
|
| 201 |
+
return padded_image
|
| 202 |
+
|
| 203 |
+
def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
|
| 204 |
+
"""
|
| 205 |
+
Compute the output size given input size and target long side length.
|
| 206 |
+
"""
|
| 207 |
+
oldh, oldw = old_shape
|
| 208 |
+
scale = longest_edge * 1.0 / max(oldh, oldw)
|
| 209 |
+
newh, neww = oldh * scale, oldw * scale
|
| 210 |
+
newh = int(newh + 0.5)
|
| 211 |
+
neww = int(neww + 0.5)
|
| 212 |
+
return (newh, neww)
|
| 213 |
+
|
| 214 |
+
def resize(
|
| 215 |
+
self,
|
| 216 |
+
image: np.ndarray,
|
| 217 |
+
size: Dict[str, int],
|
| 218 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 219 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 220 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 221 |
+
**kwargs,
|
| 222 |
+
) -> np.ndarray:
|
| 223 |
+
"""
|
| 224 |
+
Resize an image to `(size["height"], size["width"])`.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
image (`np.ndarray`):
|
| 228 |
+
Image to resize.
|
| 229 |
+
size (`Dict[str, int]`):
|
| 230 |
+
Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
|
| 231 |
+
edge of the image will be resized to the specified size, while the other edge will be resized to
|
| 232 |
+
maintain the aspect ratio.
|
| 233 |
+
resample:
|
| 234 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
| 235 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 236 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 237 |
+
image is used. Can be one of:
|
| 238 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 239 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 240 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 241 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 242 |
+
from the input image. Can be one of:
|
| 243 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 244 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
`np.ndarray`: The resized image.
|
| 248 |
+
"""
|
| 249 |
+
size = get_size_dict(size)
|
| 250 |
+
if "longest_edge" not in size:
|
| 251 |
+
raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
|
| 252 |
+
input_size = get_image_size(image, channel_dim=input_data_format)
|
| 253 |
+
output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
|
| 254 |
+
return resize(
|
| 255 |
+
image,
|
| 256 |
+
size=(output_height, output_width),
|
| 257 |
+
resample=resample,
|
| 258 |
+
data_format=data_format,
|
| 259 |
+
input_data_format=input_data_format,
|
| 260 |
+
**kwargs,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def _preprocess(
|
| 264 |
+
self,
|
| 265 |
+
image: ImageInput,
|
| 266 |
+
do_resize: bool,
|
| 267 |
+
do_rescale: bool,
|
| 268 |
+
do_normalize: bool,
|
| 269 |
+
size: Optional[Dict[str, int]] = None,
|
| 270 |
+
resample: PILImageResampling = None,
|
| 271 |
+
rescale_factor: Optional[float] = None,
|
| 272 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 273 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 274 |
+
do_pad: Optional[bool] = None,
|
| 275 |
+
pad_size: Optional[Dict[str, int]] = None,
|
| 276 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 277 |
+
):
|
| 278 |
+
if do_resize:
|
| 279 |
+
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
| 280 |
+
reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
|
| 281 |
+
|
| 282 |
+
if do_rescale:
|
| 283 |
+
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 284 |
+
|
| 285 |
+
if do_normalize:
|
| 286 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 287 |
+
|
| 288 |
+
if do_pad:
|
| 289 |
+
image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
|
| 290 |
+
|
| 291 |
+
return image, reshaped_input_size
|
| 292 |
+
|
| 293 |
+
def _preprocess_image(
|
| 294 |
+
self,
|
| 295 |
+
image: ImageInput,
|
| 296 |
+
do_resize: Optional[bool] = None,
|
| 297 |
+
size: Dict[str, int] = None,
|
| 298 |
+
resample: PILImageResampling = None,
|
| 299 |
+
do_rescale: Optional[bool] = None,
|
| 300 |
+
rescale_factor: Optional[float] = None,
|
| 301 |
+
do_normalize: Optional[bool] = None,
|
| 302 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 303 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 304 |
+
do_pad: Optional[bool] = None,
|
| 305 |
+
pad_size: Optional[Dict[str, int]] = None,
|
| 306 |
+
do_convert_rgb: Optional[bool] = None,
|
| 307 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 308 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 309 |
+
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
|
| 310 |
+
# PIL RGBA images are converted to RGB
|
| 311 |
+
if do_convert_rgb:
|
| 312 |
+
image = convert_to_rgb(image)
|
| 313 |
+
|
| 314 |
+
# All transformations expect numpy arrays.
|
| 315 |
+
image = to_numpy_array(image)
|
| 316 |
+
|
| 317 |
+
if do_rescale and is_scaled_image(image):
|
| 318 |
+
logger.warning_once(
|
| 319 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 320 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
if input_data_format is None:
|
| 324 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 325 |
+
|
| 326 |
+
original_size = get_image_size(image, channel_dim=input_data_format)
|
| 327 |
+
|
| 328 |
+
image, reshaped_input_size = self._preprocess(
|
| 329 |
+
image=image,
|
| 330 |
+
do_resize=do_resize,
|
| 331 |
+
size=size,
|
| 332 |
+
resample=resample,
|
| 333 |
+
do_rescale=do_rescale,
|
| 334 |
+
rescale_factor=rescale_factor,
|
| 335 |
+
do_normalize=do_normalize,
|
| 336 |
+
image_mean=image_mean,
|
| 337 |
+
image_std=image_std,
|
| 338 |
+
do_pad=do_pad,
|
| 339 |
+
pad_size=pad_size,
|
| 340 |
+
input_data_format=input_data_format,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if data_format is not None:
|
| 344 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 345 |
+
|
| 346 |
+
return image, original_size, reshaped_input_size
|
| 347 |
+
|
| 348 |
+
def _preprocess_mask(
|
| 349 |
+
self,
|
| 350 |
+
segmentation_map: ImageInput,
|
| 351 |
+
do_resize: Optional[bool] = None,
|
| 352 |
+
mask_size: Dict[str, int] = None,
|
| 353 |
+
do_pad: Optional[bool] = None,
|
| 354 |
+
mask_pad_size: Optional[Dict[str, int]] = None,
|
| 355 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 356 |
+
) -> np.ndarray:
|
| 357 |
+
segmentation_map = to_numpy_array(segmentation_map)
|
| 358 |
+
|
| 359 |
+
# Add channel dimension if missing - needed for certain transformations
|
| 360 |
+
if segmentation_map.ndim == 2:
|
| 361 |
+
added_channel_dim = True
|
| 362 |
+
segmentation_map = segmentation_map[None, ...]
|
| 363 |
+
input_data_format = ChannelDimension.FIRST
|
| 364 |
+
else:
|
| 365 |
+
added_channel_dim = False
|
| 366 |
+
if input_data_format is None:
|
| 367 |
+
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
| 368 |
+
|
| 369 |
+
original_size = get_image_size(segmentation_map, channel_dim=input_data_format)
|
| 370 |
+
|
| 371 |
+
segmentation_map, _ = self._preprocess(
|
| 372 |
+
image=segmentation_map,
|
| 373 |
+
do_resize=do_resize,
|
| 374 |
+
size=mask_size,
|
| 375 |
+
resample=PILImageResampling.NEAREST,
|
| 376 |
+
do_rescale=False,
|
| 377 |
+
do_normalize=False,
|
| 378 |
+
do_pad=do_pad,
|
| 379 |
+
pad_size=mask_pad_size,
|
| 380 |
+
input_data_format=input_data_format,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Remove extra channel dimension if added for processing
|
| 384 |
+
if added_channel_dim:
|
| 385 |
+
segmentation_map = segmentation_map.squeeze(0)
|
| 386 |
+
segmentation_map = segmentation_map.astype(np.int64)
|
| 387 |
+
|
| 388 |
+
return segmentation_map, original_size
|
| 389 |
+
|
| 390 |
+
@filter_out_non_signature_kwargs()
|
| 391 |
+
def preprocess(
|
| 392 |
+
self,
|
| 393 |
+
images: ImageInput,
|
| 394 |
+
segmentation_maps: Optional[ImageInput] = None,
|
| 395 |
+
do_resize: Optional[bool] = None,
|
| 396 |
+
size: Optional[Dict[str, int]] = None,
|
| 397 |
+
mask_size: Optional[Dict[str, int]] = None,
|
| 398 |
+
resample: Optional["PILImageResampling"] = None,
|
| 399 |
+
do_rescale: Optional[bool] = None,
|
| 400 |
+
rescale_factor: Optional[Union[int, float]] = None,
|
| 401 |
+
do_normalize: Optional[bool] = None,
|
| 402 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 403 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 404 |
+
do_pad: Optional[bool] = None,
|
| 405 |
+
pad_size: Optional[Dict[str, int]] = None,
|
| 406 |
+
mask_pad_size: Optional[Dict[str, int]] = None,
|
| 407 |
+
do_convert_rgb: Optional[bool] = None,
|
| 408 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 409 |
+
data_format: ChannelDimension = ChannelDimension.FIRST,
|
| 410 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 411 |
+
):
|
| 412 |
+
"""
|
| 413 |
+
Preprocess an image or batch of images.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
images (`ImageInput`):
|
| 417 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 418 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 419 |
+
segmentation_maps (`ImageInput`, *optional*):
|
| 420 |
+
Segmentation map to preprocess.
|
| 421 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 422 |
+
Whether to resize the image.
|
| 423 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 424 |
+
Controls the size of the image after `resize`. The longest edge of the image is resized to
|
| 425 |
+
`size["longest_edge"]` whilst preserving the aspect ratio.
|
| 426 |
+
mask_size (`Dict[str, int]`, *optional*, defaults to `self.mask_size`):
|
| 427 |
+
Controls the size of the segmentation map after `resize`. The longest edge of the image is resized to
|
| 428 |
+
`size["longest_edge"]` whilst preserving the aspect ratio.
|
| 429 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 430 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
| 431 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 432 |
+
Whether to rescale the image pixel values by rescaling factor.
|
| 433 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
|
| 434 |
+
Rescale factor to apply to the image pixel values.
|
| 435 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 436 |
+
Whether to normalize the image.
|
| 437 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 438 |
+
Image mean to normalize the image by if `do_normalize` is set to `True`.
|
| 439 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 440 |
+
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
|
| 441 |
+
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
| 442 |
+
Whether to pad the image.
|
| 443 |
+
pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
|
| 444 |
+
Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
|
| 445 |
+
`pad_size["width"]` if `do_pad` is set to `True`.
|
| 446 |
+
mask_pad_size (`Dict[str, int]`, *optional*, defaults to `self.mask_pad_size`):
|
| 447 |
+
Controls the size of the padding applied to the segmentation map. The image is padded to
|
| 448 |
+
`mask_pad_size["height"]` and `mask_pad_size["width"]` if `do_pad` is set to `True`.
|
| 449 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 450 |
+
Whether to convert the image to RGB.
|
| 451 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 452 |
+
The type of tensors to return. Can be one of:
|
| 453 |
+
- Unset: Return a list of `np.ndarray`.
|
| 454 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 455 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 456 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 457 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 458 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 459 |
+
The channel dimension format for the output image. Can be one of:
|
| 460 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 461 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 462 |
+
- Unset: Use the channel dimension format of the input image.
|
| 463 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 464 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 465 |
+
from the input image. Can be one of:
|
| 466 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 467 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 468 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 469 |
+
"""
|
| 470 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 471 |
+
size = size if size is not None else self.size
|
| 472 |
+
size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
|
| 473 |
+
mask_size = mask_size if mask_size is not None else self.mask_size
|
| 474 |
+
mask_size = (
|
| 475 |
+
get_size_dict(max_size=mask_size, default_to_square=False)
|
| 476 |
+
if not isinstance(mask_size, dict)
|
| 477 |
+
else mask_size
|
| 478 |
+
)
|
| 479 |
+
resample = resample if resample is not None else self.resample
|
| 480 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 481 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 482 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 483 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 484 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 485 |
+
do_pad = do_pad if do_pad is not None else self.do_pad
|
| 486 |
+
pad_size = pad_size if pad_size is not None else self.pad_size
|
| 487 |
+
pad_size = get_size_dict(pad_size, default_to_square=True)
|
| 488 |
+
mask_pad_size = mask_pad_size if mask_pad_size is not None else self.mask_pad_size
|
| 489 |
+
mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
|
| 490 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 491 |
+
|
| 492 |
+
images = make_list_of_images(images)
|
| 493 |
+
|
| 494 |
+
if not valid_images(images):
|
| 495 |
+
raise ValueError(
|
| 496 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 497 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if segmentation_maps is not None:
|
| 501 |
+
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
| 502 |
+
|
| 503 |
+
if not valid_images(segmentation_maps):
|
| 504 |
+
raise ValueError(
|
| 505 |
+
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 506 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 507 |
+
)
|
| 508 |
+
validate_preprocess_arguments(
|
| 509 |
+
do_rescale=do_rescale,
|
| 510 |
+
rescale_factor=rescale_factor,
|
| 511 |
+
do_normalize=do_normalize,
|
| 512 |
+
image_mean=image_mean,
|
| 513 |
+
image_std=image_std,
|
| 514 |
+
do_pad=do_pad,
|
| 515 |
+
size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size.
|
| 516 |
+
do_resize=do_resize,
|
| 517 |
+
size=size,
|
| 518 |
+
resample=resample,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
images, original_sizes, reshaped_input_sizes = zip(
|
| 522 |
+
*(
|
| 523 |
+
self._preprocess_image(
|
| 524 |
+
image=img,
|
| 525 |
+
do_resize=do_resize,
|
| 526 |
+
size=size,
|
| 527 |
+
resample=resample,
|
| 528 |
+
do_rescale=do_rescale,
|
| 529 |
+
rescale_factor=rescale_factor,
|
| 530 |
+
do_normalize=do_normalize,
|
| 531 |
+
image_mean=image_mean,
|
| 532 |
+
image_std=image_std,
|
| 533 |
+
do_pad=do_pad,
|
| 534 |
+
pad_size=pad_size,
|
| 535 |
+
do_convert_rgb=do_convert_rgb,
|
| 536 |
+
data_format=data_format,
|
| 537 |
+
input_data_format=input_data_format,
|
| 538 |
+
)
|
| 539 |
+
for img in images
|
| 540 |
+
)
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
data = {
|
| 544 |
+
"pixel_values": images,
|
| 545 |
+
"original_sizes": original_sizes,
|
| 546 |
+
"reshaped_input_sizes": reshaped_input_sizes,
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
if segmentation_maps is not None:
|
| 550 |
+
segmentation_maps, original_mask_sizes = zip(
|
| 551 |
+
*(
|
| 552 |
+
self._preprocess_mask(
|
| 553 |
+
segmentation_map=mask,
|
| 554 |
+
do_resize=do_resize,
|
| 555 |
+
mask_size=mask_size,
|
| 556 |
+
do_pad=do_pad,
|
| 557 |
+
mask_pad_size=mask_pad_size,
|
| 558 |
+
input_data_format=input_data_format,
|
| 559 |
+
)
|
| 560 |
+
for mask in segmentation_maps
|
| 561 |
+
)
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
# masks should start out the same size as input images
|
| 565 |
+
assert all(
|
| 566 |
+
original_im_size == original_mask_size
|
| 567 |
+
for original_im_size, original_mask_size in zip(original_sizes, original_mask_sizes)
|
| 568 |
+
), "Segmentation maps should be the same size as input images."
|
| 569 |
+
|
| 570 |
+
data["labels"] = segmentation_maps
|
| 571 |
+
|
| 572 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 573 |
+
|
| 574 |
+
def post_process_masks(
|
| 575 |
+
self,
|
| 576 |
+
masks,
|
| 577 |
+
original_sizes,
|
| 578 |
+
reshaped_input_sizes,
|
| 579 |
+
mask_threshold=0.0,
|
| 580 |
+
binarize=True,
|
| 581 |
+
pad_size=None,
|
| 582 |
+
return_tensors="pt",
|
| 583 |
+
):
|
| 584 |
+
"""
|
| 585 |
+
Remove padding and upscale masks to the original image size.
|
| 586 |
+
|
| 587 |
+
Args:
|
| 588 |
+
masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
|
| 589 |
+
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
| 590 |
+
original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
|
| 591 |
+
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
| 592 |
+
width) format.
|
| 593 |
+
reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
|
| 594 |
+
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
|
| 595 |
+
mask_threshold (`float`, *optional*, defaults to 0.0):
|
| 596 |
+
The threshold to use for binarizing the masks.
|
| 597 |
+
binarize (`bool`, *optional*, defaults to `True`):
|
| 598 |
+
Whether to binarize the masks.
|
| 599 |
+
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
| 600 |
+
The target size the images were padded to before being passed to the model. If None, the target size is
|
| 601 |
+
assumed to be the processor's `pad_size`.
|
| 602 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 603 |
+
If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
|
| 604 |
+
Returns:
|
| 605 |
+
(`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where
|
| 606 |
+
(height, width) is given by original_size.
|
| 607 |
+
"""
|
| 608 |
+
if return_tensors == "pt":
|
| 609 |
+
return self._post_process_masks_pt(
|
| 610 |
+
masks=masks,
|
| 611 |
+
original_sizes=original_sizes,
|
| 612 |
+
reshaped_input_sizes=reshaped_input_sizes,
|
| 613 |
+
mask_threshold=mask_threshold,
|
| 614 |
+
binarize=binarize,
|
| 615 |
+
pad_size=pad_size,
|
| 616 |
+
)
|
| 617 |
+
elif return_tensors == "tf":
|
| 618 |
+
return self._post_process_masks_tf(
|
| 619 |
+
masks=masks,
|
| 620 |
+
original_sizes=original_sizes,
|
| 621 |
+
reshaped_input_sizes=reshaped_input_sizes,
|
| 622 |
+
mask_threshold=mask_threshold,
|
| 623 |
+
binarize=binarize,
|
| 624 |
+
pad_size=pad_size,
|
| 625 |
+
)
|
| 626 |
+
else:
|
| 627 |
+
raise ValueError("return_tensors must be either 'pt' or 'tf'")
|
| 628 |
+
|
| 629 |
+
def _post_process_masks_pt(
|
| 630 |
+
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
| 631 |
+
):
|
| 632 |
+
"""
|
| 633 |
+
Remove padding and upscale masks to the original image size.
|
| 634 |
+
|
| 635 |
+
Args:
|
| 636 |
+
masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
|
| 637 |
+
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
| 638 |
+
original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
|
| 639 |
+
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
| 640 |
+
width) format.
|
| 641 |
+
reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
|
| 642 |
+
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
|
| 643 |
+
mask_threshold (`float`, *optional*, defaults to 0.0):
|
| 644 |
+
The threshold to use for binarizing the masks.
|
| 645 |
+
binarize (`bool`, *optional*, defaults to `True`):
|
| 646 |
+
Whether to binarize the masks.
|
| 647 |
+
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
| 648 |
+
The target size the images were padded to before being passed to the model. If None, the target size is
|
| 649 |
+
assumed to be the processor's `pad_size`.
|
| 650 |
+
Returns:
|
| 651 |
+
(`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
|
| 652 |
+
is given by original_size.
|
| 653 |
+
"""
|
| 654 |
+
requires_backends(self, ["torch"])
|
| 655 |
+
pad_size = self.pad_size if pad_size is None else pad_size
|
| 656 |
+
target_image_size = (pad_size["height"], pad_size["width"])
|
| 657 |
+
if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
|
| 658 |
+
original_sizes = original_sizes.tolist()
|
| 659 |
+
if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
|
| 660 |
+
reshaped_input_sizes = reshaped_input_sizes.tolist()
|
| 661 |
+
output_masks = []
|
| 662 |
+
for i, original_size in enumerate(original_sizes):
|
| 663 |
+
if isinstance(masks[i], np.ndarray):
|
| 664 |
+
masks[i] = torch.from_numpy(masks[i])
|
| 665 |
+
elif not isinstance(masks[i], torch.Tensor):
|
| 666 |
+
raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
|
| 667 |
+
interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
|
| 668 |
+
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
|
| 669 |
+
interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
|
| 670 |
+
if binarize:
|
| 671 |
+
interpolated_mask = interpolated_mask > mask_threshold
|
| 672 |
+
output_masks.append(interpolated_mask)
|
| 673 |
+
|
| 674 |
+
return output_masks
|
| 675 |
+
|
| 676 |
+
def _post_process_masks_tf(
|
| 677 |
+
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
| 678 |
+
):
|
| 679 |
+
"""
|
| 680 |
+
Remove padding and upscale masks to the original image size.
|
| 681 |
+
|
| 682 |
+
Args:
|
| 683 |
+
masks (`tf.Tensor`):
|
| 684 |
+
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
| 685 |
+
original_sizes (`tf.Tensor`):
|
| 686 |
+
The original size of the images before resizing for input to the model, in (height, width) format.
|
| 687 |
+
reshaped_input_sizes (`tf.Tensor`):
|
| 688 |
+
The size of the image input to the model, in (height, width) format. Used to remove padding.
|
| 689 |
+
mask_threshold (`float`, *optional*, defaults to 0.0):
|
| 690 |
+
The threshold to use for binarizing the masks.
|
| 691 |
+
binarize (`bool`, *optional*, defaults to `True`):
|
| 692 |
+
Whether to binarize the masks.
|
| 693 |
+
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
| 694 |
+
The target size the images were padded to before being passed to the model. If None, the target size is
|
| 695 |
+
assumed to be the processor's `pad_size`.
|
| 696 |
+
Returns:
|
| 697 |
+
(`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is
|
| 698 |
+
given by original_size.
|
| 699 |
+
"""
|
| 700 |
+
requires_backends(self, ["tf"])
|
| 701 |
+
pad_size = self.pad_size if pad_size is None else pad_size
|
| 702 |
+
target_image_size = (pad_size["height"], pad_size["width"])
|
| 703 |
+
|
| 704 |
+
output_masks = []
|
| 705 |
+
for i, original_size in enumerate(original_sizes):
|
| 706 |
+
# tf.image expects NHWC, we transpose the NCHW inputs for it
|
| 707 |
+
mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
|
| 708 |
+
interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
|
| 709 |
+
interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
|
| 710 |
+
interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
|
| 711 |
+
if binarize:
|
| 712 |
+
interpolated_mask = interpolated_mask > mask_threshold
|
| 713 |
+
# And then we transpose them back at the end
|
| 714 |
+
output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
|
| 715 |
+
|
| 716 |
+
return output_masks
|
| 717 |
+
|
| 718 |
+
def post_process_for_mask_generation(
|
| 719 |
+
self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
|
| 720 |
+
):
|
| 721 |
+
"""
|
| 722 |
+
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
|
| 723 |
+
|
| 724 |
+
Args:
|
| 725 |
+
all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
| 726 |
+
List of all predicted segmentation masks
|
| 727 |
+
all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
| 728 |
+
List of all predicted iou scores
|
| 729 |
+
all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
| 730 |
+
List of all bounding boxes of the predicted masks
|
| 731 |
+
crops_nms_thresh (`float`):
|
| 732 |
+
Threshold for NMS (Non Maximum Suppression) algorithm.
|
| 733 |
+
return_tensors (`str`, *optional*, defaults to `pt`):
|
| 734 |
+
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
| 735 |
+
"""
|
| 736 |
+
if return_tensors == "pt":
|
| 737 |
+
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
| 738 |
+
elif return_tensors == "tf":
|
| 739 |
+
return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
| 740 |
+
|
| 741 |
+
def generate_crop_boxes(
|
| 742 |
+
self,
|
| 743 |
+
image,
|
| 744 |
+
target_size,
|
| 745 |
+
crop_n_layers: int = 0,
|
| 746 |
+
overlap_ratio: float = 512 / 1500,
|
| 747 |
+
points_per_crop: Optional[int] = 32,
|
| 748 |
+
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
| 749 |
+
device: Optional["torch.device"] = None,
|
| 750 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 751 |
+
return_tensors: str = "pt",
|
| 752 |
+
):
|
| 753 |
+
"""
|
| 754 |
+
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
| 755 |
+
|
| 756 |
+
Args:
|
| 757 |
+
image (`np.array`):
|
| 758 |
+
Input original image
|
| 759 |
+
target_size (`int`):
|
| 760 |
+
Target size of the resized image
|
| 761 |
+
crop_n_layers (`int`, *optional*, defaults to 0):
|
| 762 |
+
If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
|
| 763 |
+
each layer has 2**i_layer number of image crops.
|
| 764 |
+
overlap_ratio (`float`, *optional*, defaults to 512/1500):
|
| 765 |
+
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
|
| 766 |
+
the image length. Later layers with more crops scale down this overlap.
|
| 767 |
+
points_per_crop (`int`, *optional*, defaults to 32):
|
| 768 |
+
Number of points to sample from each crop.
|
| 769 |
+
crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1):
|
| 770 |
+
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 771 |
+
device (`torch.device`, *optional*, defaults to None):
|
| 772 |
+
Device to use for the computation. If None, cpu will be used.
|
| 773 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 774 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 775 |
+
return_tensors (`str`, *optional*, defaults to `pt`):
|
| 776 |
+
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
| 777 |
+
"""
|
| 778 |
+
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
|
| 779 |
+
image,
|
| 780 |
+
target_size,
|
| 781 |
+
crop_n_layers,
|
| 782 |
+
overlap_ratio,
|
| 783 |
+
points_per_crop,
|
| 784 |
+
crop_n_points_downscale_factor,
|
| 785 |
+
input_data_format,
|
| 786 |
+
)
|
| 787 |
+
if return_tensors == "pt":
|
| 788 |
+
if device is None:
|
| 789 |
+
device = torch.device("cpu")
|
| 790 |
+
crop_boxes = torch.tensor(crop_boxes, device=device)
|
| 791 |
+
points_per_crop = torch.tensor(points_per_crop, device=device)
|
| 792 |
+
# cropped_images stays as np
|
| 793 |
+
input_labels = torch.tensor(input_labels, device=device)
|
| 794 |
+
|
| 795 |
+
elif return_tensors == "tf":
|
| 796 |
+
if device is not None:
|
| 797 |
+
raise ValueError("device is not a supported argument when return_tensors is tf!")
|
| 798 |
+
crop_boxes = tf.convert_to_tensor(crop_boxes)
|
| 799 |
+
points_per_crop = tf.convert_to_tensor(points_per_crop)
|
| 800 |
+
# cropped_images stays as np
|
| 801 |
+
input_labels = tf.convert_to_tensor(input_labels)
|
| 802 |
+
else:
|
| 803 |
+
raise ValueError("return_tensors must be either 'pt' or 'tf'.")
|
| 804 |
+
return crop_boxes, points_per_crop, cropped_images, input_labels
|
| 805 |
+
|
| 806 |
+
def filter_masks(
|
| 807 |
+
self,
|
| 808 |
+
masks,
|
| 809 |
+
iou_scores,
|
| 810 |
+
original_size,
|
| 811 |
+
cropped_box_image,
|
| 812 |
+
pred_iou_thresh=0.88,
|
| 813 |
+
stability_score_thresh=0.95,
|
| 814 |
+
mask_threshold=0,
|
| 815 |
+
stability_score_offset=1,
|
| 816 |
+
return_tensors="pt",
|
| 817 |
+
):
|
| 818 |
+
"""
|
| 819 |
+
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
| 820 |
+
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
| 821 |
+
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
| 822 |
+
bounding boxes and pad the predicted masks if necessary.
|
| 823 |
+
|
| 824 |
+
Args:
|
| 825 |
+
masks (`Union[torch.Tensor, tf.Tensor]`):
|
| 826 |
+
Input masks.
|
| 827 |
+
iou_scores (`Union[torch.Tensor, tf.Tensor]`):
|
| 828 |
+
List of IoU scores.
|
| 829 |
+
original_size (`Tuple[int,int]`):
|
| 830 |
+
Size of the orginal image.
|
| 831 |
+
cropped_box_image (`np.array`):
|
| 832 |
+
The cropped image.
|
| 833 |
+
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
| 834 |
+
The threshold for the iou scores.
|
| 835 |
+
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
| 836 |
+
The threshold for the stability score.
|
| 837 |
+
mask_threshold (`float`, *optional*, defaults to 0):
|
| 838 |
+
The threshold for the predicted masks.
|
| 839 |
+
stability_score_offset (`float`, *optional*, defaults to 1):
|
| 840 |
+
The offset for the stability score used in the `_compute_stability_score` method.
|
| 841 |
+
return_tensors (`str`, *optional*, defaults to `pt`):
|
| 842 |
+
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
| 843 |
+
"""
|
| 844 |
+
if return_tensors == "pt":
|
| 845 |
+
return self._filter_masks_pt(
|
| 846 |
+
masks=masks,
|
| 847 |
+
iou_scores=iou_scores,
|
| 848 |
+
original_size=original_size,
|
| 849 |
+
cropped_box_image=cropped_box_image,
|
| 850 |
+
pred_iou_thresh=pred_iou_thresh,
|
| 851 |
+
stability_score_thresh=stability_score_thresh,
|
| 852 |
+
mask_threshold=mask_threshold,
|
| 853 |
+
stability_score_offset=stability_score_offset,
|
| 854 |
+
)
|
| 855 |
+
elif return_tensors == "tf":
|
| 856 |
+
return self._filter_masks_tf(
|
| 857 |
+
masks=masks,
|
| 858 |
+
iou_scores=iou_scores,
|
| 859 |
+
original_size=original_size,
|
| 860 |
+
cropped_box_image=cropped_box_image,
|
| 861 |
+
pred_iou_thresh=pred_iou_thresh,
|
| 862 |
+
stability_score_thresh=stability_score_thresh,
|
| 863 |
+
mask_threshold=mask_threshold,
|
| 864 |
+
stability_score_offset=stability_score_offset,
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
def _filter_masks_pt(
|
| 868 |
+
self,
|
| 869 |
+
masks,
|
| 870 |
+
iou_scores,
|
| 871 |
+
original_size,
|
| 872 |
+
cropped_box_image,
|
| 873 |
+
pred_iou_thresh=0.88,
|
| 874 |
+
stability_score_thresh=0.95,
|
| 875 |
+
mask_threshold=0,
|
| 876 |
+
stability_score_offset=1,
|
| 877 |
+
):
|
| 878 |
+
"""
|
| 879 |
+
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
| 880 |
+
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
| 881 |
+
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
| 882 |
+
bounding boxes and pad the predicted masks if necessary.
|
| 883 |
+
|
| 884 |
+
Args:
|
| 885 |
+
masks (`torch.Tensor`):
|
| 886 |
+
Input masks.
|
| 887 |
+
iou_scores (`torch.Tensor`):
|
| 888 |
+
List of IoU scores.
|
| 889 |
+
original_size (`Tuple[int,int]`):
|
| 890 |
+
Size of the orginal image.
|
| 891 |
+
cropped_box_image (`np.array`):
|
| 892 |
+
The cropped image.
|
| 893 |
+
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
| 894 |
+
The threshold for the iou scores.
|
| 895 |
+
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
| 896 |
+
The threshold for the stability score.
|
| 897 |
+
mask_threshold (`float`, *optional*, defaults to 0):
|
| 898 |
+
The threshold for the predicted masks.
|
| 899 |
+
stability_score_offset (`float`, *optional*, defaults to 1):
|
| 900 |
+
The offset for the stability score used in the `_compute_stability_score` method.
|
| 901 |
+
|
| 902 |
+
"""
|
| 903 |
+
requires_backends(self, ["torch"])
|
| 904 |
+
original_height, original_width = original_size
|
| 905 |
+
iou_scores = iou_scores.flatten(0, 1)
|
| 906 |
+
masks = masks.flatten(0, 1)
|
| 907 |
+
|
| 908 |
+
if masks.shape[0] != iou_scores.shape[0]:
|
| 909 |
+
raise ValueError("masks and iou_scores must have the same batch size.")
|
| 910 |
+
|
| 911 |
+
if masks.device != iou_scores.device:
|
| 912 |
+
iou_scores = iou_scores.to(masks.device)
|
| 913 |
+
|
| 914 |
+
batch_size = masks.shape[0]
|
| 915 |
+
|
| 916 |
+
keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
|
| 917 |
+
|
| 918 |
+
if pred_iou_thresh > 0.0:
|
| 919 |
+
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
| 920 |
+
|
| 921 |
+
# compute stability score
|
| 922 |
+
if stability_score_thresh > 0.0:
|
| 923 |
+
stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
|
| 924 |
+
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
| 925 |
+
|
| 926 |
+
scores = iou_scores[keep_mask]
|
| 927 |
+
masks = masks[keep_mask]
|
| 928 |
+
|
| 929 |
+
# binarize masks
|
| 930 |
+
masks = masks > mask_threshold
|
| 931 |
+
converted_boxes = _batched_mask_to_box(masks)
|
| 932 |
+
|
| 933 |
+
keep_mask = ~_is_box_near_crop_edge(
|
| 934 |
+
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
| 935 |
+
)
|
| 936 |
+
|
| 937 |
+
scores = scores[keep_mask]
|
| 938 |
+
masks = masks[keep_mask]
|
| 939 |
+
converted_boxes = converted_boxes[keep_mask]
|
| 940 |
+
|
| 941 |
+
masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
|
| 942 |
+
# conversion to rle is necessary to run non-maximum suppresion
|
| 943 |
+
masks = _mask_to_rle_pytorch(masks)
|
| 944 |
+
|
| 945 |
+
return masks, scores, converted_boxes
|
| 946 |
+
|
| 947 |
+
def _filter_masks_tf(
|
| 948 |
+
self,
|
| 949 |
+
masks,
|
| 950 |
+
iou_scores,
|
| 951 |
+
original_size,
|
| 952 |
+
cropped_box_image,
|
| 953 |
+
pred_iou_thresh=0.88,
|
| 954 |
+
stability_score_thresh=0.95,
|
| 955 |
+
mask_threshold=0,
|
| 956 |
+
stability_score_offset=1,
|
| 957 |
+
):
|
| 958 |
+
"""
|
| 959 |
+
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
| 960 |
+
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
| 961 |
+
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
| 962 |
+
bounding boxes and pad the predicted masks if necessary.
|
| 963 |
+
|
| 964 |
+
Args:
|
| 965 |
+
masks (`tf.Tensor`):
|
| 966 |
+
Input masks.
|
| 967 |
+
iou_scores (`tf.Tensor`):
|
| 968 |
+
List of IoU scores.
|
| 969 |
+
original_size (`Tuple[int,int]`):
|
| 970 |
+
Size of the orginal image.
|
| 971 |
+
cropped_box_image (`np.array`):
|
| 972 |
+
The cropped image.
|
| 973 |
+
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
| 974 |
+
The threshold for the iou scores.
|
| 975 |
+
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
| 976 |
+
The threshold for the stability score.
|
| 977 |
+
mask_threshold (`float`, *optional*, defaults to 0):
|
| 978 |
+
The threshold for the predicted masks.
|
| 979 |
+
stability_score_offset (`float`, *optional*, defaults to 1):
|
| 980 |
+
The offset for the stability score used in the `_compute_stability_score` method.
|
| 981 |
+
|
| 982 |
+
"""
|
| 983 |
+
requires_backends(self, ["tf"])
|
| 984 |
+
original_height, original_width = original_size
|
| 985 |
+
iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
|
| 986 |
+
masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
|
| 987 |
+
|
| 988 |
+
if masks.shape[0] != iou_scores.shape[0]:
|
| 989 |
+
raise ValueError("masks and iou_scores must have the same batch size.")
|
| 990 |
+
|
| 991 |
+
batch_size = masks.shape[0]
|
| 992 |
+
|
| 993 |
+
keep_mask = tf.ones(batch_size, dtype=tf.bool)
|
| 994 |
+
|
| 995 |
+
if pred_iou_thresh > 0.0:
|
| 996 |
+
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
| 997 |
+
|
| 998 |
+
# compute stability score
|
| 999 |
+
if stability_score_thresh > 0.0:
|
| 1000 |
+
stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
|
| 1001 |
+
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
| 1002 |
+
|
| 1003 |
+
scores = iou_scores[keep_mask]
|
| 1004 |
+
masks = masks[keep_mask]
|
| 1005 |
+
|
| 1006 |
+
# binarize masks
|
| 1007 |
+
masks = masks > mask_threshold
|
| 1008 |
+
converted_boxes = _batched_mask_to_box_tf(masks)
|
| 1009 |
+
|
| 1010 |
+
keep_mask = ~_is_box_near_crop_edge_tf(
|
| 1011 |
+
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
scores = scores[keep_mask]
|
| 1015 |
+
masks = masks[keep_mask]
|
| 1016 |
+
converted_boxes = converted_boxes[keep_mask]
|
| 1017 |
+
|
| 1018 |
+
masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
|
| 1019 |
+
# conversion to rle is necessary to run non-maximum suppresion
|
| 1020 |
+
masks = _mask_to_rle_tf(masks)
|
| 1021 |
+
|
| 1022 |
+
return masks, scores, converted_boxes
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
| 1026 |
+
# One mask is always contained inside the other.
|
| 1027 |
+
# Save memory by preventing unnecesary cast to torch.int64
|
| 1028 |
+
intersections = (
|
| 1029 |
+
(masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
| 1030 |
+
)
|
| 1031 |
+
unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
| 1032 |
+
stability_scores = intersections / unions
|
| 1033 |
+
return stability_scores
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
|
| 1037 |
+
# Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
|
| 1038 |
+
# we get the right division results.
|
| 1039 |
+
intersections = tf.count_nonzero(
|
| 1040 |
+
masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
|
| 1041 |
+
)
|
| 1042 |
+
unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
|
| 1043 |
+
stability_scores = intersections / unions
|
| 1044 |
+
return stability_scores
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
def _build_point_grid(n_per_side: int) -> np.ndarray:
|
| 1048 |
+
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
| 1049 |
+
offset = 1 / (2 * n_per_side)
|
| 1050 |
+
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
| 1051 |
+
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
| 1052 |
+
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
| 1053 |
+
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
| 1054 |
+
return points
|
| 1055 |
+
|
| 1056 |
+
|
| 1057 |
+
def _normalize_coordinates(
|
| 1058 |
+
target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False
|
| 1059 |
+
) -> np.ndarray:
|
| 1060 |
+
"""
|
| 1061 |
+
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
|
| 1062 |
+
format.
|
| 1063 |
+
"""
|
| 1064 |
+
old_height, old_width = original_size
|
| 1065 |
+
|
| 1066 |
+
scale = target_size * 1.0 / max(old_height, old_width)
|
| 1067 |
+
new_height, new_width = old_height * scale, old_width * scale
|
| 1068 |
+
new_width = int(new_width + 0.5)
|
| 1069 |
+
new_height = int(new_height + 0.5)
|
| 1070 |
+
|
| 1071 |
+
coords = deepcopy(coords).astype(float)
|
| 1072 |
+
|
| 1073 |
+
if is_bounding_box:
|
| 1074 |
+
coords = coords.reshape(-1, 2, 2)
|
| 1075 |
+
|
| 1076 |
+
coords[..., 0] = coords[..., 0] * (new_width / old_width)
|
| 1077 |
+
coords[..., 1] = coords[..., 1] * (new_height / old_height)
|
| 1078 |
+
|
| 1079 |
+
if is_bounding_box:
|
| 1080 |
+
coords = coords.reshape(-1, 4)
|
| 1081 |
+
|
| 1082 |
+
return coords
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
def _generate_crop_boxes(
|
| 1086 |
+
image,
|
| 1087 |
+
target_size: int, # Is it tuple here?
|
| 1088 |
+
crop_n_layers: int = 0,
|
| 1089 |
+
overlap_ratio: float = 512 / 1500,
|
| 1090 |
+
points_per_crop: Optional[int] = 32,
|
| 1091 |
+
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
| 1092 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 1093 |
+
) -> Tuple[List[List[int]], List[int]]:
|
| 1094 |
+
"""
|
| 1095 |
+
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
| 1096 |
+
|
| 1097 |
+
Args:
|
| 1098 |
+
image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
|
| 1099 |
+
Image to generate crops for.
|
| 1100 |
+
target_size (`int`):
|
| 1101 |
+
Size of the smallest crop.
|
| 1102 |
+
crop_n_layers (`int`, *optional*):
|
| 1103 |
+
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
|
| 1104 |
+
to run, where each layer has 2**i_layer number of image crops.
|
| 1105 |
+
overlap_ratio (`int`, *optional*):
|
| 1106 |
+
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
|
| 1107 |
+
image length. Later layers with more crops scale down this overlap.
|
| 1108 |
+
points_per_crop (`int`, *optional*):
|
| 1109 |
+
Number of points to sample per crop.
|
| 1110 |
+
crop_n_points_downscale_factor (`int`, *optional*):
|
| 1111 |
+
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 1112 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 1113 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 1114 |
+
"""
|
| 1115 |
+
|
| 1116 |
+
if isinstance(image, list):
|
| 1117 |
+
raise ValueError("Only one image is allowed for crop generation.")
|
| 1118 |
+
image = to_numpy_array(image)
|
| 1119 |
+
original_size = get_image_size(image, input_data_format)
|
| 1120 |
+
|
| 1121 |
+
points_grid = []
|
| 1122 |
+
for i in range(crop_n_layers + 1):
|
| 1123 |
+
n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
|
| 1124 |
+
points_grid.append(_build_point_grid(n_points))
|
| 1125 |
+
|
| 1126 |
+
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
|
| 1127 |
+
|
| 1128 |
+
cropped_images, point_grid_per_crop = _generate_crop_images(
|
| 1129 |
+
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
|
| 1130 |
+
)
|
| 1131 |
+
crop_boxes = np.array(crop_boxes)
|
| 1132 |
+
crop_boxes = crop_boxes.astype(np.float32)
|
| 1133 |
+
points_per_crop = np.array([point_grid_per_crop])
|
| 1134 |
+
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
|
| 1135 |
+
|
| 1136 |
+
input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
|
| 1137 |
+
|
| 1138 |
+
return crop_boxes, points_per_crop, cropped_images, input_labels
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
|
| 1142 |
+
"""
|
| 1143 |
+
Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format
|
| 1144 |
+
consists of the following required indices:
|
| 1145 |
+
- X: X coordinate of the top left of the bounding box
|
| 1146 |
+
- Y: Y coordinate of the top left of the bounding box
|
| 1147 |
+
- W: width of the bounding box
|
| 1148 |
+
- H: height of the bounding box
|
| 1149 |
+
"""
|
| 1150 |
+
crop_boxes, layer_idxs = [], []
|
| 1151 |
+
im_height, im_width = original_size
|
| 1152 |
+
short_side = min(im_height, im_width)
|
| 1153 |
+
|
| 1154 |
+
# Original image
|
| 1155 |
+
crop_boxes.append([0, 0, im_width, im_height])
|
| 1156 |
+
layer_idxs.append(0)
|
| 1157 |
+
for i_layer in range(crop_n_layers):
|
| 1158 |
+
n_crops_per_side = 2 ** (i_layer + 1)
|
| 1159 |
+
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
| 1160 |
+
|
| 1161 |
+
crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
|
| 1162 |
+
crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
|
| 1163 |
+
|
| 1164 |
+
crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
|
| 1165 |
+
crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
|
| 1166 |
+
|
| 1167 |
+
for left, top in product(crop_box_x0, crop_box_y0):
|
| 1168 |
+
box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
|
| 1169 |
+
crop_boxes.append(box)
|
| 1170 |
+
layer_idxs.append(i_layer + 1)
|
| 1171 |
+
|
| 1172 |
+
return crop_boxes, layer_idxs
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
def _generate_crop_images(
|
| 1176 |
+
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
|
| 1177 |
+
):
|
| 1178 |
+
"""
|
| 1179 |
+
Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are
|
| 1180 |
+
also passed.
|
| 1181 |
+
"""
|
| 1182 |
+
cropped_images = []
|
| 1183 |
+
total_points_per_crop = []
|
| 1184 |
+
for i, crop_box in enumerate(crop_boxes):
|
| 1185 |
+
left, top, right, bottom = crop_box
|
| 1186 |
+
|
| 1187 |
+
channel_dim = infer_channel_dimension_format(image, input_data_format)
|
| 1188 |
+
if channel_dim == ChannelDimension.LAST:
|
| 1189 |
+
cropped_im = image[top:bottom, left:right, :]
|
| 1190 |
+
else:
|
| 1191 |
+
cropped_im = image[:, top:bottom, left:right]
|
| 1192 |
+
|
| 1193 |
+
cropped_images.append(cropped_im)
|
| 1194 |
+
|
| 1195 |
+
cropped_im_size = get_image_size(cropped_im, channel_dim)
|
| 1196 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
| 1197 |
+
|
| 1198 |
+
points = points_grid[layer_idxs[i]] * points_scale
|
| 1199 |
+
normalized_points = _normalize_coordinates(target_size, points, original_size)
|
| 1200 |
+
total_points_per_crop.append(normalized_points)
|
| 1201 |
+
|
| 1202 |
+
return cropped_images, total_points_per_crop
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
|
| 1206 |
+
left, top, right, bottom = crop_box
|
| 1207 |
+
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
| 1208 |
+
return masks
|
| 1209 |
+
# Coordinate transform masks
|
| 1210 |
+
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
| 1211 |
+
pad = (left, pad_x - left, top, pad_y - top)
|
| 1212 |
+
return torch.nn.functional.pad(masks, pad, value=0)
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
|
| 1216 |
+
left, top, right, bottom = crop_box
|
| 1217 |
+
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
| 1218 |
+
return masks
|
| 1219 |
+
# Coordinate transform masks
|
| 1220 |
+
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
| 1221 |
+
pad = (left, pad_x - left, top, pad_y - top)
|
| 1222 |
+
return tf.pad(masks, pad, constant_values=0)
|
| 1223 |
+
|
| 1224 |
+
|
| 1225 |
+
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
| 1226 |
+
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
| 1227 |
+
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
| 1228 |
+
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
| 1229 |
+
|
| 1230 |
+
left, top, _, _ = crop_box
|
| 1231 |
+
offset = torch.tensor([[left, top, left, top]], device=boxes.device)
|
| 1232 |
+
# Check if boxes has a channel dimension
|
| 1233 |
+
if len(boxes.shape) == 3:
|
| 1234 |
+
offset = offset.unsqueeze(1)
|
| 1235 |
+
boxes = (boxes + offset).float()
|
| 1236 |
+
|
| 1237 |
+
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
| 1238 |
+
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
| 1239 |
+
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
| 1240 |
+
return torch.any(near_crop_edge, dim=1)
|
| 1241 |
+
|
| 1242 |
+
|
| 1243 |
+
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
|
| 1244 |
+
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
| 1245 |
+
crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
|
| 1246 |
+
orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
|
| 1247 |
+
|
| 1248 |
+
left, top, _, _ = crop_box
|
| 1249 |
+
offset = tf.convert_to_tensor([[left, top, left, top]])
|
| 1250 |
+
# Check if boxes has a channel dimension
|
| 1251 |
+
if len(boxes.shape) == 3:
|
| 1252 |
+
offset = tf.expand_dims(offset, 1)
|
| 1253 |
+
boxes = tf.cast(boxes + offset, tf.float32)
|
| 1254 |
+
|
| 1255 |
+
near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
|
| 1256 |
+
near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
|
| 1257 |
+
near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
|
| 1258 |
+
return tf.reduce_any(near_crop_edge, axis=1)
|
| 1259 |
+
|
| 1260 |
+
|
| 1261 |
+
def _batched_mask_to_box(masks: "torch.Tensor"):
|
| 1262 |
+
"""
|
| 1263 |
+
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
| 1264 |
+
corresponds the following required indices:
|
| 1265 |
+
- LEFT: left hand side of the bounding box
|
| 1266 |
+
- TOP: top of the bounding box
|
| 1267 |
+
- RIGHT: right of the bounding box
|
| 1268 |
+
- BOTTOM: bottom of the bounding box
|
| 1269 |
+
|
| 1270 |
+
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
| 1271 |
+
is channel_1 x channel_2 x ... x 4.
|
| 1272 |
+
|
| 1273 |
+
Args:
|
| 1274 |
+
- masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
|
| 1275 |
+
"""
|
| 1276 |
+
# torch.max below raises an error on empty inputs, just skip in this case
|
| 1277 |
+
|
| 1278 |
+
if torch.numel(masks) == 0:
|
| 1279 |
+
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
| 1280 |
+
|
| 1281 |
+
# Normalize shape to Cxheightxwidth
|
| 1282 |
+
shape = masks.shape
|
| 1283 |
+
height, width = shape[-2:]
|
| 1284 |
+
|
| 1285 |
+
# Get top and bottom edges
|
| 1286 |
+
in_height, _ = torch.max(masks, dim=-1)
|
| 1287 |
+
in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
|
| 1288 |
+
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
| 1289 |
+
in_height_coords = in_height_coords + height * (~in_height)
|
| 1290 |
+
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
| 1291 |
+
|
| 1292 |
+
# Get left and right edges
|
| 1293 |
+
in_width, _ = torch.max(masks, dim=-2)
|
| 1294 |
+
in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
|
| 1295 |
+
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
| 1296 |
+
in_width_coords = in_width_coords + width * (~in_width)
|
| 1297 |
+
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
| 1298 |
+
|
| 1299 |
+
# If the mask is empty the right edge will be to the left of the left edge.
|
| 1300 |
+
# Replace these boxes with [0, 0, 0, 0]
|
| 1301 |
+
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
| 1302 |
+
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
| 1303 |
+
out = out * (~empty_filter).unsqueeze(-1)
|
| 1304 |
+
|
| 1305 |
+
# Return to original shape
|
| 1306 |
+
out = out.reshape(*shape[:-2], 4)
|
| 1307 |
+
return out
|
| 1308 |
+
|
| 1309 |
+
|
| 1310 |
+
def _batched_mask_to_box_tf(masks: "tf.Tensor"):
|
| 1311 |
+
"""
|
| 1312 |
+
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
| 1313 |
+
corresponds the following required indices:
|
| 1314 |
+
- LEFT: left hand side of the bounding box
|
| 1315 |
+
- TOP: top of the bounding box
|
| 1316 |
+
- RIGHT: right of the bounding box
|
| 1317 |
+
- BOTTOM: bottom of the bounding box
|
| 1318 |
+
|
| 1319 |
+
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
| 1320 |
+
is channel_1 x channel_2 x ... x 4.
|
| 1321 |
+
|
| 1322 |
+
Args:
|
| 1323 |
+
- masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
|
| 1324 |
+
"""
|
| 1325 |
+
|
| 1326 |
+
if tf.size(masks) == 0:
|
| 1327 |
+
return tf.zeros([*masks.shape[:-2], 4])
|
| 1328 |
+
|
| 1329 |
+
# Normalize shape to Cxheightxwidth
|
| 1330 |
+
shape = shape_list(masks)
|
| 1331 |
+
height, width = shape[-2:]
|
| 1332 |
+
|
| 1333 |
+
# Get top and bottom edges
|
| 1334 |
+
in_height = tf.reduce_max(masks, axis=-1)
|
| 1335 |
+
in_height_coords = in_height * tf.range(height)[None, :]
|
| 1336 |
+
bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
|
| 1337 |
+
in_height_coords = in_height_coords + height * (~in_height)
|
| 1338 |
+
top_edges = tf.reduce_min(in_height_coords, axis=-1)
|
| 1339 |
+
|
| 1340 |
+
# Get left and right edges
|
| 1341 |
+
in_width, _ = tf.reduce_max(masks, axis=-2)
|
| 1342 |
+
in_width_coords = in_width * tf.range(width)[None, :]
|
| 1343 |
+
right_edges, _ = tf.reduce_max(in_width_coords, axis=-1)
|
| 1344 |
+
in_width_coords = in_width_coords + width * (~in_width)
|
| 1345 |
+
left_edges, _ = tf.reduce_min(in_width_coords, axis=-1)
|
| 1346 |
+
|
| 1347 |
+
# If the mask is empty the right edge will be to the left of the left edge.
|
| 1348 |
+
# Replace these boxes with [0, 0, 0, 0]
|
| 1349 |
+
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
| 1350 |
+
out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
|
| 1351 |
+
out = out * tf.expand_dims(~empty_filter, -1)
|
| 1352 |
+
|
| 1353 |
+
# Return to original shape
|
| 1354 |
+
out = tf.reshape(out, *shape[:-2], 4)
|
| 1355 |
+
return out
|
| 1356 |
+
|
| 1357 |
+
|
| 1358 |
+
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
| 1359 |
+
"""
|
| 1360 |
+
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
| 1361 |
+
"""
|
| 1362 |
+
# Put in fortran order and flatten height and width
|
| 1363 |
+
batch_size, height, width = input_mask.shape
|
| 1364 |
+
input_mask = input_mask.permute(0, 2, 1).flatten(1)
|
| 1365 |
+
|
| 1366 |
+
# Compute change indices
|
| 1367 |
+
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
| 1368 |
+
change_indices = diff.nonzero()
|
| 1369 |
+
|
| 1370 |
+
# Encode run length
|
| 1371 |
+
out = []
|
| 1372 |
+
for i in range(batch_size):
|
| 1373 |
+
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
| 1374 |
+
if len(cur_idxs) == 0:
|
| 1375 |
+
# No changes => either all 0 or all 1
|
| 1376 |
+
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
|
| 1377 |
+
if input_mask[i, 0] == 0:
|
| 1378 |
+
out.append({"size": [height, width], "counts": [height * width]})
|
| 1379 |
+
else:
|
| 1380 |
+
out.append({"size": [height, width], "counts": [0, height * width]})
|
| 1381 |
+
continue
|
| 1382 |
+
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
| 1383 |
+
counts = [] if input_mask[i, 0] == 0 else [0]
|
| 1384 |
+
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
|
| 1385 |
+
out.append({"size": [height, width], "counts": counts})
|
| 1386 |
+
return out
|
| 1387 |
+
|
| 1388 |
+
|
| 1389 |
+
def _mask_to_rle_tf(input_mask: "tf.Tensor"):
|
| 1390 |
+
"""
|
| 1391 |
+
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
| 1392 |
+
"""
|
| 1393 |
+
# Put in fortran order and flatten height and width
|
| 1394 |
+
batch_size, height, width = input_mask.shape
|
| 1395 |
+
input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
|
| 1396 |
+
|
| 1397 |
+
# Compute change indices
|
| 1398 |
+
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
| 1399 |
+
change_indices = tf.where(diff)
|
| 1400 |
+
|
| 1401 |
+
# Encode run length
|
| 1402 |
+
out = []
|
| 1403 |
+
for i in range(batch_size):
|
| 1404 |
+
cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
|
| 1405 |
+
if len(cur_idxs) == 0:
|
| 1406 |
+
# No changes => either all 0 or all 1
|
| 1407 |
+
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
|
| 1408 |
+
if input_mask[i, 0] == 0:
|
| 1409 |
+
out.append({"size": [height, width], "counts": [height * width]})
|
| 1410 |
+
else:
|
| 1411 |
+
out.append({"size": [height, width], "counts": [0, height * width]})
|
| 1412 |
+
continue
|
| 1413 |
+
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
| 1414 |
+
counts = [] if input_mask[i, 0] == 0 else [0]
|
| 1415 |
+
counts += (
|
| 1416 |
+
[cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
|
| 1417 |
+
)
|
| 1418 |
+
out.append({"size": [height, width], "counts": counts})
|
| 1419 |
+
return out
|
| 1420 |
+
|
| 1421 |
+
|
| 1422 |
+
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
| 1423 |
+
"""Compute a binary mask from an uncompressed RLE."""
|
| 1424 |
+
height, width = rle["size"]
|
| 1425 |
+
mask = np.empty(height * width, dtype=bool)
|
| 1426 |
+
idx = 0
|
| 1427 |
+
parity = False
|
| 1428 |
+
for count in rle["counts"]:
|
| 1429 |
+
mask[idx : idx + count] = parity
|
| 1430 |
+
idx += count
|
| 1431 |
+
parity = not parity
|
| 1432 |
+
mask = mask.reshape(width, height)
|
| 1433 |
+
return mask.transpose() # Reshape to original shape
|
| 1434 |
+
|
| 1435 |
+
|
| 1436 |
+
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
| 1437 |
+
"""
|
| 1438 |
+
Perform NMS (Non Maximum Suppression) on the outputs.
|
| 1439 |
+
|
| 1440 |
+
Args:
|
| 1441 |
+
rle_masks (`torch.Tensor`):
|
| 1442 |
+
binary masks in the RLE format
|
| 1443 |
+
iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
|
| 1444 |
+
iou_scores predicted by the model
|
| 1445 |
+
mask_boxes (`torch.Tensor`):
|
| 1446 |
+
The bounding boxes corresponding to segmentation masks
|
| 1447 |
+
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
| 1448 |
+
NMS threshold.
|
| 1449 |
+
"""
|
| 1450 |
+
keep_by_nms = batched_nms(
|
| 1451 |
+
boxes=mask_boxes.float(),
|
| 1452 |
+
scores=iou_scores,
|
| 1453 |
+
idxs=torch.zeros(mask_boxes.shape[0]),
|
| 1454 |
+
iou_threshold=amg_crops_nms_thresh,
|
| 1455 |
+
)
|
| 1456 |
+
|
| 1457 |
+
iou_scores = iou_scores[keep_by_nms]
|
| 1458 |
+
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
| 1459 |
+
mask_boxes = mask_boxes[keep_by_nms]
|
| 1460 |
+
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
| 1461 |
+
|
| 1462 |
+
return masks, iou_scores, rle_masks, mask_boxes
|
| 1463 |
+
|
| 1464 |
+
|
| 1465 |
+
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
| 1466 |
+
"""
|
| 1467 |
+
Perform NMS (Non Maximum Suppression) on the outputs.
|
| 1468 |
+
|
| 1469 |
+
Args:
|
| 1470 |
+
rle_masks (`tf.Tensor`):
|
| 1471 |
+
binary masks in the RLE format
|
| 1472 |
+
iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
|
| 1473 |
+
iou_scores predicted by the model
|
| 1474 |
+
mask_boxes (`tf.Tensor`):
|
| 1475 |
+
The bounding boxes corresponding to segmentation masks
|
| 1476 |
+
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
| 1477 |
+
NMS threshold.
|
| 1478 |
+
"""
|
| 1479 |
+
keep_by_nms = tf.image.combined_non_max_suppression(
|
| 1480 |
+
boxes=mask_boxes.float(),
|
| 1481 |
+
scores=iou_scores,
|
| 1482 |
+
idxs=torch.zeros(mask_boxes.shape[0]),
|
| 1483 |
+
iou_threshold=amg_crops_nms_thresh,
|
| 1484 |
+
)
|
| 1485 |
+
|
| 1486 |
+
iou_scores = iou_scores[keep_by_nms]
|
| 1487 |
+
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
| 1488 |
+
mask_boxes = mask_boxes[keep_by_nms]
|
| 1489 |
+
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
| 1490 |
+
|
| 1491 |
+
return masks, iou_scores, rle_masks, mask_boxes
|
| 1492 |
+
|
| 1493 |
+
|
| 1494 |
+
__all__ = ["SamImageProcessor"]
|
docs/transformers/build/lib/transformers/models/sam/modeling_sam.py
ADDED
|
@@ -0,0 +1,1579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch SAM model."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import Tensor, nn
|
| 26 |
+
|
| 27 |
+
from ...activations import ACT2FN
|
| 28 |
+
from ...modeling_outputs import BaseModelOutput
|
| 29 |
+
from ...modeling_utils import PreTrainedModel
|
| 30 |
+
from ...utils import (
|
| 31 |
+
ModelOutput,
|
| 32 |
+
add_start_docstrings,
|
| 33 |
+
add_start_docstrings_to_model_forward,
|
| 34 |
+
can_return_tuple,
|
| 35 |
+
logging,
|
| 36 |
+
replace_return_docstrings,
|
| 37 |
+
)
|
| 38 |
+
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
_CONFIG_FOR_DOC = "SamConfig"
|
| 44 |
+
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class SamVisionEncoderOutput(ModelOutput):
|
| 49 |
+
"""
|
| 50 |
+
Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
|
| 51 |
+
layer to the pooler_output.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
| 55 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
| 56 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 57 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 58 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 59 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 60 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 61 |
+
|
| 62 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 63 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 64 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 65 |
+
sequence_length)`.
|
| 66 |
+
|
| 67 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 68 |
+
heads.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
| 72 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 73 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 74 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class SamImageSegmentationOutput(ModelOutput):
|
| 79 |
+
"""
|
| 80 |
+
Base class for Segment-Anything model's output
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
|
| 84 |
+
The iou scores of the predicted masks.
|
| 85 |
+
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
|
| 86 |
+
The predicted low resolutions masks. Needs to be post-processed by the processor
|
| 87 |
+
vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 88 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 89 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 90 |
+
|
| 91 |
+
Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
|
| 92 |
+
vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 93 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 94 |
+
sequence_length)`.
|
| 95 |
+
|
| 96 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 97 |
+
heads.
|
| 98 |
+
mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 99 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 100 |
+
sequence_length)`.
|
| 101 |
+
|
| 102 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 103 |
+
heads.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
iou_scores: Optional[torch.FloatTensor] = None
|
| 107 |
+
pred_masks: Optional[torch.FloatTensor] = None
|
| 108 |
+
vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 109 |
+
vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 110 |
+
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class SamPatchEmbeddings(nn.Module):
|
| 114 |
+
"""
|
| 115 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 116 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 117 |
+
Transformer.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self, config):
|
| 121 |
+
super().__init__()
|
| 122 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 123 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 124 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 125 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 126 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 127 |
+
self.image_size = image_size
|
| 128 |
+
self.patch_size = patch_size
|
| 129 |
+
self.num_channels = num_channels
|
| 130 |
+
self.num_patches = num_patches
|
| 131 |
+
|
| 132 |
+
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
| 133 |
+
|
| 134 |
+
def forward(self, pixel_values):
|
| 135 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 136 |
+
if num_channels != self.num_channels:
|
| 137 |
+
raise ValueError(
|
| 138 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 139 |
+
)
|
| 140 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
| 141 |
+
raise ValueError(
|
| 142 |
+
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
| 143 |
+
)
|
| 144 |
+
embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
|
| 145 |
+
return embeddings
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class SamMLPBlock(nn.Module):
|
| 149 |
+
def __init__(self, config):
|
| 150 |
+
super().__init__()
|
| 151 |
+
self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
|
| 152 |
+
self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
|
| 153 |
+
self.act = ACT2FN[config.hidden_act]
|
| 154 |
+
|
| 155 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 156 |
+
hidden_states = self.lin1(hidden_states)
|
| 157 |
+
hidden_states = self.act(hidden_states)
|
| 158 |
+
hidden_states = self.lin2(hidden_states)
|
| 159 |
+
return hidden_states
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
|
| 163 |
+
class SamLayerNorm(nn.Module):
|
| 164 |
+
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 165 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
|
| 166 |
+
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
| 170 |
+
super().__init__()
|
| 171 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 172 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 173 |
+
self.eps = eps
|
| 174 |
+
self.data_format = data_format
|
| 175 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 176 |
+
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
|
| 177 |
+
self.normalized_shape = (normalized_shape,)
|
| 178 |
+
|
| 179 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 180 |
+
if self.data_format == "channels_last":
|
| 181 |
+
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 182 |
+
elif self.data_format == "channels_first":
|
| 183 |
+
input_dtype = x.dtype
|
| 184 |
+
x = x.float()
|
| 185 |
+
u = x.mean(1, keepdim=True)
|
| 186 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 187 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 188 |
+
x = x.to(dtype=input_dtype)
|
| 189 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 190 |
+
return x
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class SamAttention(nn.Module):
|
| 194 |
+
"""
|
| 195 |
+
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
| 196 |
+
values.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
def __init__(self, config, downsample_rate=None):
|
| 200 |
+
super().__init__()
|
| 201 |
+
self.hidden_size = config.hidden_size
|
| 202 |
+
|
| 203 |
+
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
|
| 204 |
+
|
| 205 |
+
self.internal_dim = config.hidden_size // downsample_rate
|
| 206 |
+
self.num_attention_heads = config.num_attention_heads
|
| 207 |
+
if self.internal_dim % config.num_attention_heads != 0:
|
| 208 |
+
raise ValueError("num_attention_heads must divide hidden_size.")
|
| 209 |
+
|
| 210 |
+
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
| 211 |
+
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
| 212 |
+
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
| 213 |
+
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
|
| 214 |
+
|
| 215 |
+
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
| 216 |
+
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
| 217 |
+
c_per_head = channel // num_attention_heads
|
| 218 |
+
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
|
| 219 |
+
return hidden_states.transpose(1, 2)
|
| 220 |
+
|
| 221 |
+
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
| 222 |
+
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
|
| 223 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 224 |
+
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
| 225 |
+
|
| 226 |
+
def forward(
|
| 227 |
+
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
|
| 228 |
+
) -> Tensor:
|
| 229 |
+
# Input projections
|
| 230 |
+
query = self.q_proj(query)
|
| 231 |
+
key = self.k_proj(key)
|
| 232 |
+
value = self.v_proj(value)
|
| 233 |
+
|
| 234 |
+
point_batch_size = query.shape[1]
|
| 235 |
+
# Separate into heads
|
| 236 |
+
query = self._separate_heads(query, self.num_attention_heads)
|
| 237 |
+
key = self._separate_heads(key, self.num_attention_heads)
|
| 238 |
+
value = self._separate_heads(value, self.num_attention_heads)
|
| 239 |
+
|
| 240 |
+
# SamAttention
|
| 241 |
+
_, _, _, c_per_head = query.shape
|
| 242 |
+
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
|
| 243 |
+
attn = attn / (c_per_head**0.5)
|
| 244 |
+
attn = torch.softmax(attn, dim=-1)
|
| 245 |
+
|
| 246 |
+
if attention_similarity is not None:
|
| 247 |
+
attn = attn + attention_similarity
|
| 248 |
+
attn = torch.softmax(attn, dim=-1)
|
| 249 |
+
|
| 250 |
+
# Get output
|
| 251 |
+
out = attn @ value
|
| 252 |
+
out = self._recombine_heads(out, point_batch_size)
|
| 253 |
+
out = self.out_proj(out)
|
| 254 |
+
|
| 255 |
+
return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class SamSdpaAttention(SamAttention):
|
| 259 |
+
"""
|
| 260 |
+
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
| 261 |
+
values. Using SDPA instead of the default attention.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
def __init__(self, config, downsample_rate=None):
|
| 265 |
+
super().__init__(config, downsample_rate)
|
| 266 |
+
|
| 267 |
+
def forward(
|
| 268 |
+
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
|
| 269 |
+
) -> Tensor:
|
| 270 |
+
# Input projections
|
| 271 |
+
query = self.q_proj(query)
|
| 272 |
+
key = self.k_proj(key)
|
| 273 |
+
value = self.v_proj(value)
|
| 274 |
+
|
| 275 |
+
point_batch_size = query.shape[1]
|
| 276 |
+
# Separate into heads
|
| 277 |
+
query = self._separate_heads(query, self.num_attention_heads)
|
| 278 |
+
key = self._separate_heads(key, self.num_attention_heads)
|
| 279 |
+
value = self._separate_heads(value, self.num_attention_heads)
|
| 280 |
+
|
| 281 |
+
# Scaled dot product attention
|
| 282 |
+
attn_mask = None
|
| 283 |
+
if attention_similarity is not None:
|
| 284 |
+
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
|
| 285 |
+
|
| 286 |
+
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
|
| 287 |
+
|
| 288 |
+
# Get output
|
| 289 |
+
out = self._recombine_heads(out, point_batch_size)
|
| 290 |
+
out = self.out_proj(out)
|
| 291 |
+
|
| 292 |
+
return out
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
SAM_ATTENTION_CLASSES = {
|
| 296 |
+
"eager": SamAttention,
|
| 297 |
+
"sdpa": SamSdpaAttention,
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class SamTwoWayAttentionBlock(nn.Module):
|
| 302 |
+
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
|
| 303 |
+
"""
|
| 304 |
+
A transformer block with four layers:
|
| 305 |
+
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
|
| 306 |
+
sparse inputs (4) cross attention of dense inputs -> sparse inputs
|
| 307 |
+
|
| 308 |
+
Arguments:
|
| 309 |
+
config (`SamMaskDecoderConfig`):
|
| 310 |
+
The configuration file used to instantiate the block
|
| 311 |
+
attention_downsample_rate (*optionalk*, int, defaults to 2):
|
| 312 |
+
The downsample ratio of the block used to reduce the inner dim of the attention.
|
| 313 |
+
skip_first_layer_pe (*optional*, bool, defaults to `False`):
|
| 314 |
+
Whether or not to skip the addition of the query_point_embedding on the first layer.
|
| 315 |
+
"""
|
| 316 |
+
super().__init__()
|
| 317 |
+
|
| 318 |
+
self.hidden_size = config.hidden_size
|
| 319 |
+
self.layer_norm_eps = config.layer_norm_eps
|
| 320 |
+
|
| 321 |
+
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
|
| 322 |
+
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
| 323 |
+
|
| 324 |
+
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
|
| 325 |
+
config, downsample_rate=attention_downsample_rate
|
| 326 |
+
)
|
| 327 |
+
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
| 328 |
+
|
| 329 |
+
self.mlp = SamMLPBlock(config)
|
| 330 |
+
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
| 331 |
+
|
| 332 |
+
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
| 333 |
+
self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
|
| 334 |
+
config, downsample_rate=attention_downsample_rate
|
| 335 |
+
)
|
| 336 |
+
self.skip_first_layer_pe = skip_first_layer_pe
|
| 337 |
+
|
| 338 |
+
def forward(
|
| 339 |
+
self,
|
| 340 |
+
queries: Tensor,
|
| 341 |
+
keys: Tensor,
|
| 342 |
+
query_point_embedding: Tensor,
|
| 343 |
+
key_point_embedding: Tensor,
|
| 344 |
+
attention_similarity: Tensor,
|
| 345 |
+
output_attentions: bool = False,
|
| 346 |
+
):
|
| 347 |
+
# Self attention block
|
| 348 |
+
if self.skip_first_layer_pe:
|
| 349 |
+
queries = self.self_attn(query=queries, key=queries, value=queries)
|
| 350 |
+
else:
|
| 351 |
+
query = queries + query_point_embedding
|
| 352 |
+
attn_out = self.self_attn(query=query, key=query, value=queries)
|
| 353 |
+
queries = queries + attn_out
|
| 354 |
+
queries = self.layer_norm1(queries)
|
| 355 |
+
|
| 356 |
+
# Cross attention block, tokens attending to image embedding
|
| 357 |
+
query = queries + query_point_embedding
|
| 358 |
+
key = keys + key_point_embedding
|
| 359 |
+
|
| 360 |
+
attn_out = self.cross_attn_token_to_image(
|
| 361 |
+
query=query, key=key, value=keys, attention_similarity=attention_similarity
|
| 362 |
+
)
|
| 363 |
+
queries = queries + attn_out
|
| 364 |
+
|
| 365 |
+
queries = self.layer_norm2(queries)
|
| 366 |
+
|
| 367 |
+
# MLP block
|
| 368 |
+
mlp_out = self.mlp(queries)
|
| 369 |
+
queries = queries + mlp_out
|
| 370 |
+
queries = self.layer_norm3(queries)
|
| 371 |
+
|
| 372 |
+
# Cross attention block, image embedding attending to tokens
|
| 373 |
+
query = queries + query_point_embedding
|
| 374 |
+
key = keys + key_point_embedding
|
| 375 |
+
|
| 376 |
+
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
|
| 377 |
+
keys = keys + attn_out
|
| 378 |
+
|
| 379 |
+
keys = self.layer_norm4(keys)
|
| 380 |
+
|
| 381 |
+
outputs = (queries, keys)
|
| 382 |
+
|
| 383 |
+
if output_attentions:
|
| 384 |
+
outputs = outputs + (attn_out,)
|
| 385 |
+
else:
|
| 386 |
+
outputs = outputs + (None,)
|
| 387 |
+
|
| 388 |
+
return outputs
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class SamTwoWayTransformer(nn.Module):
|
| 392 |
+
def __init__(self, config: SamMaskDecoderConfig):
|
| 393 |
+
super().__init__()
|
| 394 |
+
self.config = config
|
| 395 |
+
|
| 396 |
+
self.num_hidden_layers = config.num_hidden_layers
|
| 397 |
+
self.layers = nn.ModuleList()
|
| 398 |
+
|
| 399 |
+
for i in range(self.num_hidden_layers):
|
| 400 |
+
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
|
| 401 |
+
|
| 402 |
+
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
|
| 403 |
+
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
|
| 404 |
+
|
| 405 |
+
def forward(
|
| 406 |
+
self,
|
| 407 |
+
point_embeddings: Tensor,
|
| 408 |
+
image_embeddings: Tensor,
|
| 409 |
+
image_positional_embeddings: Tensor,
|
| 410 |
+
attention_similarity: Tensor,
|
| 411 |
+
target_embedding=None,
|
| 412 |
+
output_attentions: Optional[bool] = None,
|
| 413 |
+
output_hidden_states: Optional[bool] = None,
|
| 414 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 415 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 416 |
+
output_hidden_states = (
|
| 417 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
all_attentions = ()
|
| 421 |
+
|
| 422 |
+
if image_embeddings is None:
|
| 423 |
+
raise ValueError("You have to specify an image_embedding")
|
| 424 |
+
|
| 425 |
+
image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
|
| 426 |
+
image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
|
| 427 |
+
|
| 428 |
+
# Prepare queries
|
| 429 |
+
queries = point_embeddings
|
| 430 |
+
keys = image_embeddings
|
| 431 |
+
|
| 432 |
+
# Apply transformer blocks and final layernorm
|
| 433 |
+
for layer in self.layers:
|
| 434 |
+
if target_embedding is not None:
|
| 435 |
+
queries += target_embedding
|
| 436 |
+
|
| 437 |
+
queries, keys, attention_outputs = layer(
|
| 438 |
+
queries=queries,
|
| 439 |
+
keys=keys,
|
| 440 |
+
query_point_embedding=point_embeddings,
|
| 441 |
+
key_point_embedding=image_positional_embeddings,
|
| 442 |
+
attention_similarity=attention_similarity,
|
| 443 |
+
output_attentions=output_attentions,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
if output_attentions:
|
| 447 |
+
all_attentions = all_attentions + (attention_outputs,)
|
| 448 |
+
|
| 449 |
+
# Apply the final attenion layer from the points to the image
|
| 450 |
+
query = queries + point_embeddings
|
| 451 |
+
key = keys + image_positional_embeddings
|
| 452 |
+
|
| 453 |
+
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
|
| 454 |
+
|
| 455 |
+
queries = queries + attn_out
|
| 456 |
+
queries = self.layer_norm_final_attn(queries)
|
| 457 |
+
return queries, keys, all_attentions
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class SamFeedForward(nn.Module):
|
| 461 |
+
def __init__(
|
| 462 |
+
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
|
| 463 |
+
):
|
| 464 |
+
super().__init__()
|
| 465 |
+
self.num_layers = num_layers
|
| 466 |
+
self.activation = nn.ReLU()
|
| 467 |
+
self.proj_in = nn.Linear(input_dim, hidden_dim)
|
| 468 |
+
self.proj_out = nn.Linear(hidden_dim, output_dim)
|
| 469 |
+
self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
|
| 470 |
+
self.sigmoid_output = sigmoid_output
|
| 471 |
+
|
| 472 |
+
def forward(self, hidden_states):
|
| 473 |
+
hidden_states = self.proj_in(hidden_states)
|
| 474 |
+
hidden_states = self.activation(hidden_states)
|
| 475 |
+
for layer in self.layers:
|
| 476 |
+
hidden_states = self.activation(layer(hidden_states))
|
| 477 |
+
|
| 478 |
+
hidden_states = self.proj_out(hidden_states)
|
| 479 |
+
if self.sigmoid_output:
|
| 480 |
+
hidden_states = F.sigmoid(hidden_states)
|
| 481 |
+
return hidden_states
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class SamMaskDecoder(nn.Module):
|
| 485 |
+
def __init__(self, config: SamMaskDecoderConfig):
|
| 486 |
+
super().__init__()
|
| 487 |
+
self.config = config
|
| 488 |
+
self.hidden_size = config.hidden_size
|
| 489 |
+
|
| 490 |
+
self.num_multimask_outputs = config.num_multimask_outputs
|
| 491 |
+
self.num_mask_tokens = config.num_multimask_outputs + 1
|
| 492 |
+
|
| 493 |
+
self.iou_token = nn.Embedding(1, self.hidden_size)
|
| 494 |
+
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
|
| 495 |
+
|
| 496 |
+
self.transformer = SamTwoWayTransformer(config)
|
| 497 |
+
|
| 498 |
+
# should we create a new class for this?
|
| 499 |
+
self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
|
| 500 |
+
self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
|
| 501 |
+
self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
|
| 502 |
+
self.activation = nn.GELU()
|
| 503 |
+
|
| 504 |
+
mlps_list = []
|
| 505 |
+
for _ in range(self.num_mask_tokens):
|
| 506 |
+
mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
|
| 507 |
+
self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
|
| 508 |
+
|
| 509 |
+
self.iou_prediction_head = SamFeedForward(
|
| 510 |
+
self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
def forward(
|
| 514 |
+
self,
|
| 515 |
+
image_embeddings: torch.Tensor,
|
| 516 |
+
image_positional_embeddings: torch.Tensor,
|
| 517 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 518 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 519 |
+
multimask_output: bool,
|
| 520 |
+
output_attentions: Optional[bool] = None,
|
| 521 |
+
attention_similarity: Optional[torch.Tensor] = None,
|
| 522 |
+
target_embedding: Optional[torch.Tensor] = None,
|
| 523 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 524 |
+
"""
|
| 525 |
+
Predict masks given image and prompt embeddings.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
image_embeddings (`torch.Tensor`):
|
| 529 |
+
the embeddings from the image encoder
|
| 530 |
+
image_positional_embedding (`torch.Tensor`):
|
| 531 |
+
positional encoding with the shape of image_embeddings
|
| 532 |
+
sparse_prompt_embeddings (`torch.Tensor`):
|
| 533 |
+
The embeddings of the points and boxes
|
| 534 |
+
dense_prompt_embeddings (`torch.Tensor`):
|
| 535 |
+
the embeddings of the mask inputs
|
| 536 |
+
multimask_output (bool):
|
| 537 |
+
Whether to return multiple masks or a single mask.
|
| 538 |
+
output_attentions (bool, *optional*):
|
| 539 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 540 |
+
"""
|
| 541 |
+
batch_size, num_channels, height, width = image_embeddings.shape
|
| 542 |
+
point_batch_size = sparse_prompt_embeddings.shape[1]
|
| 543 |
+
# Concatenate output tokens
|
| 544 |
+
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
| 545 |
+
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
| 546 |
+
|
| 547 |
+
if sparse_prompt_embeddings.sum().item() != 0:
|
| 548 |
+
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
|
| 549 |
+
else:
|
| 550 |
+
tokens = output_tokens
|
| 551 |
+
point_embeddings = tokens.to(self.iou_token.weight.dtype)
|
| 552 |
+
|
| 553 |
+
# Expand per-image data in batch direction to be per-point
|
| 554 |
+
image_embeddings = image_embeddings + dense_prompt_embeddings
|
| 555 |
+
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
|
| 556 |
+
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
|
| 557 |
+
|
| 558 |
+
# Run the transformer, image_positional_embedding are consumed
|
| 559 |
+
point_embedding, image_embeddings, attentions = self.transformer(
|
| 560 |
+
point_embeddings=point_embeddings,
|
| 561 |
+
image_embeddings=image_embeddings,
|
| 562 |
+
image_positional_embeddings=image_positional_embeddings,
|
| 563 |
+
attention_similarity=attention_similarity,
|
| 564 |
+
target_embedding=target_embedding,
|
| 565 |
+
output_attentions=output_attentions,
|
| 566 |
+
)
|
| 567 |
+
iou_token_out = point_embedding[:, :, 0, :]
|
| 568 |
+
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
|
| 569 |
+
|
| 570 |
+
# Upscale mask embeddings and predict masks using the mask tokens
|
| 571 |
+
image_embeddings = image_embeddings.transpose(2, 3).reshape(
|
| 572 |
+
batch_size * point_batch_size, num_channels, height, width
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
upscaled_embedding = self.upscale_conv1(image_embeddings)
|
| 576 |
+
upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
|
| 577 |
+
upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
|
| 578 |
+
|
| 579 |
+
hyper_in_list = []
|
| 580 |
+
for i in range(self.num_mask_tokens):
|
| 581 |
+
current_mlp = self.output_hypernetworks_mlps[i]
|
| 582 |
+
hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
|
| 583 |
+
hyper_in = torch.stack(hyper_in_list, dim=2)
|
| 584 |
+
|
| 585 |
+
_, num_channels, height, width = upscaled_embedding.shape
|
| 586 |
+
upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
|
| 587 |
+
masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
|
| 588 |
+
|
| 589 |
+
# Generate mask quality predictions
|
| 590 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
| 591 |
+
|
| 592 |
+
# Select the correct mask or masks for output
|
| 593 |
+
if multimask_output:
|
| 594 |
+
mask_slice = slice(1, None)
|
| 595 |
+
else:
|
| 596 |
+
mask_slice = slice(0, 1)
|
| 597 |
+
masks = masks[:, :, mask_slice, :, :]
|
| 598 |
+
iou_pred = iou_pred[:, :, mask_slice]
|
| 599 |
+
|
| 600 |
+
outputs = (masks, iou_pred)
|
| 601 |
+
|
| 602 |
+
if output_attentions:
|
| 603 |
+
outputs = outputs + (attentions,)
|
| 604 |
+
else:
|
| 605 |
+
outputs = outputs + (None,)
|
| 606 |
+
|
| 607 |
+
return outputs
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class SamPositionalEmbedding(nn.Module):
|
| 611 |
+
def __init__(self, config):
|
| 612 |
+
super().__init__()
|
| 613 |
+
self.scale = config.hidden_size // 2
|
| 614 |
+
self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
|
| 615 |
+
|
| 616 |
+
def forward(self, input_coords, input_shape=None):
|
| 617 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 618 |
+
coordinates = input_coords.clone()
|
| 619 |
+
|
| 620 |
+
if input_shape is not None:
|
| 621 |
+
coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
|
| 622 |
+
coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
|
| 623 |
+
|
| 624 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 625 |
+
coordinates = 2 * coordinates - 1
|
| 626 |
+
coordinates = coordinates.to(self.positional_embedding.dtype)
|
| 627 |
+
coordinates = coordinates @ self.positional_embedding
|
| 628 |
+
coordinates = 2 * np.pi * coordinates
|
| 629 |
+
# outputs d_1 x ... x d_n x channel shape
|
| 630 |
+
return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class SamMaskEmbedding(nn.Module):
|
| 634 |
+
def __init__(self, config: SamPromptEncoderConfig):
|
| 635 |
+
super().__init__()
|
| 636 |
+
self.mask_input_channels = config.mask_input_channels // 4
|
| 637 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 638 |
+
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
|
| 639 |
+
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
|
| 640 |
+
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
|
| 641 |
+
self.layer_norm1 = SamLayerNorm(
|
| 642 |
+
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
|
| 643 |
+
)
|
| 644 |
+
self.layer_norm2 = SamLayerNorm(
|
| 645 |
+
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
def forward(self, masks):
|
| 649 |
+
hidden_states = self.conv1(masks)
|
| 650 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 651 |
+
hidden_states = self.activation(hidden_states)
|
| 652 |
+
|
| 653 |
+
hidden_states = self.conv2(hidden_states)
|
| 654 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 655 |
+
hidden_states = self.activation(hidden_states)
|
| 656 |
+
dense_embeddings = self.conv3(hidden_states)
|
| 657 |
+
return dense_embeddings
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class SamPromptEncoder(nn.Module):
|
| 661 |
+
def __init__(self, config: SamPromptEncoderConfig):
|
| 662 |
+
super().__init__()
|
| 663 |
+
self.shared_embedding = SamPositionalEmbedding(config.vision_config)
|
| 664 |
+
config = config.prompt_encoder_config
|
| 665 |
+
self.mask_embed = SamMaskEmbedding(config)
|
| 666 |
+
self.no_mask_embed = nn.Embedding(1, config.hidden_size)
|
| 667 |
+
|
| 668 |
+
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
|
| 669 |
+
self.input_image_size = config.image_size
|
| 670 |
+
|
| 671 |
+
self.point_embed = nn.ModuleList(
|
| 672 |
+
[nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
|
| 673 |
+
)
|
| 674 |
+
self.hidden_size = config.hidden_size
|
| 675 |
+
self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
|
| 676 |
+
|
| 677 |
+
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
|
| 678 |
+
"""Embeds point prompts."""
|
| 679 |
+
points = points + 0.5 # Shift to center of pixel
|
| 680 |
+
if pad:
|
| 681 |
+
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
|
| 682 |
+
target_labels_shape = (points.shape[0], points.shape[1], 1)
|
| 683 |
+
padding_point = torch.zeros(target_point_shape, device=points.device)
|
| 684 |
+
padding_label = -torch.ones(target_labels_shape, device=labels.device)
|
| 685 |
+
points = torch.cat([points, padding_point], dim=2)
|
| 686 |
+
labels = torch.cat([labels, padding_label], dim=2)
|
| 687 |
+
input_shape = (self.input_image_size, self.input_image_size)
|
| 688 |
+
point_embedding = self.shared_embedding(points, input_shape)
|
| 689 |
+
|
| 690 |
+
# torch.where and expanding the labels tensor is required by the ONNX export
|
| 691 |
+
point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
|
| 692 |
+
|
| 693 |
+
# This is required for the ONNX export. The dtype, device need to be explicitely
|
| 694 |
+
# specificed as otherwise torch.onnx.export interprets as double
|
| 695 |
+
point_embedding = torch.where(
|
| 696 |
+
labels[..., None] != -10,
|
| 697 |
+
point_embedding,
|
| 698 |
+
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
point_embedding = torch.where(
|
| 702 |
+
(labels == 0)[:, :, :, None],
|
| 703 |
+
point_embedding + self.point_embed[0].weight[None, None, :, :],
|
| 704 |
+
point_embedding,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
point_embedding = torch.where(
|
| 708 |
+
(labels == 1)[:, :, :, None],
|
| 709 |
+
point_embedding + self.point_embed[1].weight[None, None, :, :],
|
| 710 |
+
point_embedding,
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
return point_embedding
|
| 714 |
+
|
| 715 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
| 716 |
+
"""Embeds box prompts."""
|
| 717 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
| 718 |
+
batch_size, nb_boxes = boxes.shape[:2]
|
| 719 |
+
coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
|
| 720 |
+
input_shape = (self.input_image_size, self.input_image_size)
|
| 721 |
+
corner_embedding = self.shared_embedding(coords, input_shape)
|
| 722 |
+
corner_embedding[:, :, 0, :] += self.point_embed[2].weight
|
| 723 |
+
corner_embedding[:, :, 1, :] += self.point_embed[3].weight
|
| 724 |
+
return corner_embedding
|
| 725 |
+
|
| 726 |
+
def forward(
|
| 727 |
+
self,
|
| 728 |
+
input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 729 |
+
input_labels: Optional[torch.Tensor],
|
| 730 |
+
input_boxes: Optional[torch.Tensor],
|
| 731 |
+
input_masks: Optional[torch.Tensor],
|
| 732 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 733 |
+
"""
|
| 734 |
+
Embeds different types of prompts, returning both sparse and dense embeddings.
|
| 735 |
+
|
| 736 |
+
Args:
|
| 737 |
+
points (`torch.Tensor`, *optional*):
|
| 738 |
+
point coordinates and labels to embed.
|
| 739 |
+
boxes (`torch.Tensor`, *optional*):
|
| 740 |
+
boxes to embed
|
| 741 |
+
masks (`torch.Tensor`, *optional*):
|
| 742 |
+
masks to embed
|
| 743 |
+
"""
|
| 744 |
+
sparse_embeddings = None
|
| 745 |
+
batch_size = 1
|
| 746 |
+
target_device = self.shared_embedding.positional_embedding.device
|
| 747 |
+
if input_points is not None:
|
| 748 |
+
batch_size, point_batch_size = input_points.shape[:2]
|
| 749 |
+
if input_labels is None:
|
| 750 |
+
raise ValueError("If points are provided, labels must also be provided.")
|
| 751 |
+
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
| 752 |
+
sparse_embeddings = point_embeddings
|
| 753 |
+
if input_boxes is not None:
|
| 754 |
+
batch_size = input_boxes.shape[0]
|
| 755 |
+
box_embeddings = self._embed_boxes(input_boxes)
|
| 756 |
+
if sparse_embeddings is None:
|
| 757 |
+
sparse_embeddings = box_embeddings
|
| 758 |
+
else:
|
| 759 |
+
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
|
| 760 |
+
if input_masks is not None:
|
| 761 |
+
dense_embeddings = self.mask_embed(input_masks)
|
| 762 |
+
else:
|
| 763 |
+
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
| 764 |
+
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
if sparse_embeddings is None:
|
| 768 |
+
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
|
| 769 |
+
|
| 770 |
+
return sparse_embeddings, dense_embeddings
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
class SamVisionAttention(nn.Module):
|
| 774 |
+
"""Multi-head Attention block with relative position embeddings."""
|
| 775 |
+
|
| 776 |
+
def __init__(self, config, window_size):
|
| 777 |
+
super().__init__()
|
| 778 |
+
input_size = (
|
| 779 |
+
(config.image_size // config.patch_size, config.image_size // config.patch_size)
|
| 780 |
+
if window_size == 0
|
| 781 |
+
else (window_size, window_size)
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
self.num_attention_heads = config.num_attention_heads
|
| 785 |
+
head_dim = config.hidden_size // config.num_attention_heads
|
| 786 |
+
self.scale = head_dim**-0.5
|
| 787 |
+
self.dropout = config.attention_dropout
|
| 788 |
+
|
| 789 |
+
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
|
| 790 |
+
self.proj = nn.Linear(config.hidden_size, config.hidden_size)
|
| 791 |
+
|
| 792 |
+
self.use_rel_pos = config.use_rel_pos
|
| 793 |
+
if self.use_rel_pos:
|
| 794 |
+
if input_size is None:
|
| 795 |
+
raise ValueError("Input size must be provided if using relative positional encoding.")
|
| 796 |
+
|
| 797 |
+
# initialize relative positional embeddings
|
| 798 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
| 799 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
| 800 |
+
|
| 801 |
+
def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
| 802 |
+
"""
|
| 803 |
+
Get relative positional embeddings according to the relative positions of
|
| 804 |
+
query and key sizes.
|
| 805 |
+
|
| 806 |
+
Args:
|
| 807 |
+
q_size (int):
|
| 808 |
+
size of the query.
|
| 809 |
+
k_size (int):
|
| 810 |
+
size of key k.
|
| 811 |
+
rel_pos (`torch.Tensor`):
|
| 812 |
+
relative position embeddings (L, channel).
|
| 813 |
+
|
| 814 |
+
Returns:
|
| 815 |
+
Extracted positional embeddings according to relative positions.
|
| 816 |
+
"""
|
| 817 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
| 818 |
+
# Interpolate rel pos.
|
| 819 |
+
rel_pos_resized = F.interpolate(
|
| 820 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
| 821 |
+
size=max_rel_dist,
|
| 822 |
+
mode="linear",
|
| 823 |
+
)
|
| 824 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
| 825 |
+
|
| 826 |
+
# Scale the coords with short length if shapes for q and k are different.
|
| 827 |
+
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
| 828 |
+
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
| 829 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
| 830 |
+
|
| 831 |
+
return rel_pos_resized[relative_coords.long()]
|
| 832 |
+
|
| 833 |
+
def get_decomposed_rel_pos(
|
| 834 |
+
self,
|
| 835 |
+
query: torch.Tensor,
|
| 836 |
+
rel_pos_h: torch.Tensor,
|
| 837 |
+
rel_pos_w: torch.Tensor,
|
| 838 |
+
q_size: Tuple[int, int],
|
| 839 |
+
k_size: Tuple[int, int],
|
| 840 |
+
) -> torch.Tensor:
|
| 841 |
+
"""
|
| 842 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
| 843 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
| 844 |
+
|
| 845 |
+
Args:
|
| 846 |
+
query (`torch.Tensor`):
|
| 847 |
+
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
|
| 848 |
+
rel_pos_h (`torch.Tensor`):
|
| 849 |
+
relative position embeddings (Lh, channel) for height axis.
|
| 850 |
+
rel_pos_w (`torch.Tensor`):
|
| 851 |
+
relative position embeddings (Lw, channel) for width axis.
|
| 852 |
+
q_size (tuple):
|
| 853 |
+
spatial sequence size of query q with (query_height, query_width).
|
| 854 |
+
k_size (tuple):
|
| 855 |
+
spatial sequence size of key k with (key_height, key_width).
|
| 856 |
+
|
| 857 |
+
Returns:
|
| 858 |
+
decomposed_rel_pos (`torch.Tensor`):
|
| 859 |
+
decomposed relative position embeddings.
|
| 860 |
+
"""
|
| 861 |
+
query_height, query_width = q_size
|
| 862 |
+
key_height, key_width = k_size
|
| 863 |
+
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
|
| 864 |
+
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
|
| 865 |
+
|
| 866 |
+
batch_size, _, dim = query.shape
|
| 867 |
+
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
|
| 868 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
|
| 869 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
|
| 870 |
+
|
| 871 |
+
decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
| 872 |
+
|
| 873 |
+
return decomposed_rel_pos
|
| 874 |
+
|
| 875 |
+
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
| 876 |
+
batch_size, height, width, _ = hidden_states.shape
|
| 877 |
+
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
| 878 |
+
qkv = (
|
| 879 |
+
self.qkv(hidden_states)
|
| 880 |
+
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
|
| 881 |
+
.permute(2, 0, 3, 1, 4)
|
| 882 |
+
)
|
| 883 |
+
# q, k, v with shape (batch_size * nHead, height * width, channel)
|
| 884 |
+
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
|
| 885 |
+
|
| 886 |
+
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
| 887 |
+
|
| 888 |
+
if self.use_rel_pos:
|
| 889 |
+
decomposed_rel_pos = self.get_decomposed_rel_pos(
|
| 890 |
+
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
| 891 |
+
)
|
| 892 |
+
decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
|
| 893 |
+
attn_weights = attn_weights + decomposed_rel_pos
|
| 894 |
+
|
| 895 |
+
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
| 896 |
+
|
| 897 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 898 |
+
|
| 899 |
+
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
| 900 |
+
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
|
| 901 |
+
|
| 902 |
+
attn_output = self.proj(attn_output)
|
| 903 |
+
|
| 904 |
+
if output_attentions:
|
| 905 |
+
outputs = (attn_output, attn_weights)
|
| 906 |
+
else:
|
| 907 |
+
outputs = (attn_output, None)
|
| 908 |
+
|
| 909 |
+
return outputs
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
class SamVisionSdpaAttention(SamVisionAttention):
|
| 913 |
+
"""
|
| 914 |
+
Multi-head Attention block with relative position embeddings.
|
| 915 |
+
Using SDPA instead of the default attention.
|
| 916 |
+
"""
|
| 917 |
+
|
| 918 |
+
def __init__(self, config, window_size):
|
| 919 |
+
super().__init__(config, window_size)
|
| 920 |
+
|
| 921 |
+
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
| 922 |
+
if output_attentions:
|
| 923 |
+
logger.warning_once(
|
| 924 |
+
"`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
| 925 |
+
"`output_attentions=True`. Falling back to the manual attention implementation, but "
|
| 926 |
+
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
| 927 |
+
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 928 |
+
)
|
| 929 |
+
return super().forward(
|
| 930 |
+
hidden_states=hidden_states,
|
| 931 |
+
output_attentions=output_attentions,
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
batch_size, height, width, _ = hidden_states.shape
|
| 935 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
| 936 |
+
qkv = (
|
| 937 |
+
self.qkv(hidden_states)
|
| 938 |
+
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
|
| 939 |
+
.permute(2, 0, 3, 1, 4)
|
| 940 |
+
)
|
| 941 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
| 942 |
+
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
|
| 943 |
+
|
| 944 |
+
attn_bias = None
|
| 945 |
+
if self.use_rel_pos:
|
| 946 |
+
decomposed_rel_pos = self.get_decomposed_rel_pos(
|
| 947 |
+
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
| 948 |
+
)
|
| 949 |
+
decomposed_rel_pos = decomposed_rel_pos.reshape(
|
| 950 |
+
batch_size, self.num_attention_heads, height * width, height * width
|
| 951 |
+
)
|
| 952 |
+
attn_bias = decomposed_rel_pos
|
| 953 |
+
|
| 954 |
+
query = query.view(batch_size, self.num_attention_heads, height * width, -1)
|
| 955 |
+
key = key.view(batch_size, self.num_attention_heads, height * width, -1)
|
| 956 |
+
value = value.view(batch_size, self.num_attention_heads, height * width, -1)
|
| 957 |
+
|
| 958 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
|
| 959 |
+
|
| 960 |
+
attn_output = (
|
| 961 |
+
attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
|
| 962 |
+
.permute(0, 2, 3, 1, 4)
|
| 963 |
+
.reshape(batch_size, height, width, -1)
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
attn_output = self.proj(attn_output)
|
| 967 |
+
|
| 968 |
+
return attn_output, None
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
SAM_VISION_ATTENTION_CLASSES = {
|
| 972 |
+
"eager": SamVisionAttention,
|
| 973 |
+
"sdpa": SamVisionSdpaAttention,
|
| 974 |
+
}
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
class SamVisionLayer(nn.Module):
|
| 978 |
+
def __init__(self, config, window_size):
|
| 979 |
+
super().__init__()
|
| 980 |
+
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 981 |
+
self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
|
| 982 |
+
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 983 |
+
self.mlp = SamMLPBlock(config)
|
| 984 |
+
self.window_size = window_size
|
| 985 |
+
|
| 986 |
+
def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
| 987 |
+
"""
|
| 988 |
+
Args:
|
| 989 |
+
Partition into non-overlapping windows with padding if needed.
|
| 990 |
+
hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window
|
| 991 |
+
size.
|
| 992 |
+
|
| 993 |
+
Returns:
|
| 994 |
+
windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
|
| 995 |
+
(pad_height, pad_width): padded height and width before partition
|
| 996 |
+
"""
|
| 997 |
+
batch_size, height, width, channel = hidden_states.shape
|
| 998 |
+
|
| 999 |
+
pad_h = (window_size - height % window_size) % window_size
|
| 1000 |
+
pad_w = (window_size - width % window_size) % window_size
|
| 1001 |
+
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
|
| 1002 |
+
pad_height, pad_width = height + pad_h, width + pad_w
|
| 1003 |
+
|
| 1004 |
+
hidden_states = hidden_states.reshape(
|
| 1005 |
+
batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
|
| 1006 |
+
)
|
| 1007 |
+
windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
|
| 1008 |
+
return windows, (pad_height, pad_width)
|
| 1009 |
+
|
| 1010 |
+
def window_unpartition(
|
| 1011 |
+
self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
|
| 1012 |
+
) -> torch.Tensor:
|
| 1013 |
+
"""
|
| 1014 |
+
Args:
|
| 1015 |
+
Window unpartition into original sequences and removing padding.
|
| 1016 |
+
hidden_states (tensor):
|
| 1017 |
+
input tokens with [batch_size * num_windows, window_size, window_size, channel].
|
| 1018 |
+
window_size (int):
|
| 1019 |
+
window size.
|
| 1020 |
+
padding_shape (Tuple):
|
| 1021 |
+
padded height and width (pad_height, pad_width).
|
| 1022 |
+
original_shape (Tuple): original height and width (height, width) before padding.
|
| 1023 |
+
|
| 1024 |
+
Returns:
|
| 1025 |
+
hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
|
| 1026 |
+
"""
|
| 1027 |
+
pad_height, pad_width = padding_shape
|
| 1028 |
+
height, width = original_shape
|
| 1029 |
+
batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
|
| 1030 |
+
hidden_states = windows.reshape(
|
| 1031 |
+
batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
|
| 1032 |
+
)
|
| 1033 |
+
hidden_states = (
|
| 1034 |
+
hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
hidden_states = hidden_states[:, :height, :width, :].contiguous()
|
| 1038 |
+
return hidden_states
|
| 1039 |
+
|
| 1040 |
+
def forward(
|
| 1041 |
+
self,
|
| 1042 |
+
hidden_states: torch.Tensor,
|
| 1043 |
+
output_attentions: Optional[bool] = False,
|
| 1044 |
+
) -> Tuple[torch.FloatTensor]:
|
| 1045 |
+
residual = hidden_states
|
| 1046 |
+
|
| 1047 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 1048 |
+
# Window partition
|
| 1049 |
+
if self.window_size > 0:
|
| 1050 |
+
height, width = hidden_states.shape[1], hidden_states.shape[2]
|
| 1051 |
+
hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
|
| 1052 |
+
|
| 1053 |
+
hidden_states, attn_weights = self.attn(
|
| 1054 |
+
hidden_states=hidden_states,
|
| 1055 |
+
output_attentions=output_attentions,
|
| 1056 |
+
)
|
| 1057 |
+
# Reverse window partition
|
| 1058 |
+
if self.window_size > 0:
|
| 1059 |
+
hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
|
| 1060 |
+
|
| 1061 |
+
hidden_states = residual + hidden_states
|
| 1062 |
+
layernorm_output = self.layer_norm2(hidden_states)
|
| 1063 |
+
hidden_states = hidden_states + self.mlp(layernorm_output)
|
| 1064 |
+
|
| 1065 |
+
outputs = (hidden_states,)
|
| 1066 |
+
if output_attentions:
|
| 1067 |
+
outputs += (attn_weights,)
|
| 1068 |
+
|
| 1069 |
+
return outputs
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
class SamVisionNeck(nn.Module):
|
| 1073 |
+
def __init__(self, config: SamVisionConfig):
|
| 1074 |
+
super().__init__()
|
| 1075 |
+
self.config = config
|
| 1076 |
+
|
| 1077 |
+
self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
|
| 1078 |
+
self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
|
| 1079 |
+
self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
|
| 1080 |
+
self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
|
| 1081 |
+
|
| 1082 |
+
def forward(self, hidden_states):
|
| 1083 |
+
hidden_states = hidden_states.permute(0, 3, 1, 2)
|
| 1084 |
+
hidden_states = self.conv1(hidden_states)
|
| 1085 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 1086 |
+
|
| 1087 |
+
hidden_states = self.conv2(hidden_states)
|
| 1088 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 1089 |
+
return hidden_states
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
class SamVisionEncoder(nn.Module):
|
| 1093 |
+
def __init__(self, config: SamVisionConfig):
|
| 1094 |
+
super().__init__()
|
| 1095 |
+
self.config = config
|
| 1096 |
+
self.image_size = config.image_size
|
| 1097 |
+
|
| 1098 |
+
self.patch_embed = SamPatchEmbeddings(config)
|
| 1099 |
+
|
| 1100 |
+
self.pos_embed = None
|
| 1101 |
+
if config.use_abs_pos:
|
| 1102 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 1103 |
+
self.pos_embed = nn.Parameter(
|
| 1104 |
+
torch.zeros(
|
| 1105 |
+
1,
|
| 1106 |
+
config.image_size // config.patch_size,
|
| 1107 |
+
config.image_size // config.patch_size,
|
| 1108 |
+
config.hidden_size,
|
| 1109 |
+
)
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
self.layers = nn.ModuleList()
|
| 1113 |
+
for i in range(config.num_hidden_layers):
|
| 1114 |
+
layer = SamVisionLayer(
|
| 1115 |
+
config,
|
| 1116 |
+
window_size=config.window_size if i not in config.global_attn_indexes else 0,
|
| 1117 |
+
)
|
| 1118 |
+
self.layers.append(layer)
|
| 1119 |
+
|
| 1120 |
+
self.neck = SamVisionNeck(config)
|
| 1121 |
+
|
| 1122 |
+
self.gradient_checkpointing = False
|
| 1123 |
+
|
| 1124 |
+
def get_input_embeddings(self):
|
| 1125 |
+
return self.patch_embed
|
| 1126 |
+
|
| 1127 |
+
@can_return_tuple
|
| 1128 |
+
def forward(
|
| 1129 |
+
self,
|
| 1130 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1131 |
+
output_attentions: Optional[bool] = None,
|
| 1132 |
+
output_hidden_states: Optional[bool] = None,
|
| 1133 |
+
) -> SamVisionEncoderOutput:
|
| 1134 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1135 |
+
output_hidden_states = (
|
| 1136 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
if pixel_values is None:
|
| 1140 |
+
raise ValueError("You have to specify pixel_values")
|
| 1141 |
+
|
| 1142 |
+
hidden_states = self.patch_embed(pixel_values)
|
| 1143 |
+
if self.pos_embed is not None:
|
| 1144 |
+
hidden_states = hidden_states + self.pos_embed
|
| 1145 |
+
|
| 1146 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1147 |
+
all_self_attentions = () if output_attentions else None
|
| 1148 |
+
|
| 1149 |
+
for i, layer_module in enumerate(self.layers):
|
| 1150 |
+
if output_hidden_states:
|
| 1151 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1152 |
+
|
| 1153 |
+
if self.gradient_checkpointing and self.training:
|
| 1154 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1155 |
+
layer_module.__call__,
|
| 1156 |
+
hidden_states,
|
| 1157 |
+
)
|
| 1158 |
+
else:
|
| 1159 |
+
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
|
| 1160 |
+
|
| 1161 |
+
hidden_states = layer_outputs[0]
|
| 1162 |
+
|
| 1163 |
+
if output_attentions:
|
| 1164 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 1165 |
+
|
| 1166 |
+
if output_hidden_states:
|
| 1167 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1168 |
+
|
| 1169 |
+
hidden_states = self.neck(hidden_states)
|
| 1170 |
+
|
| 1171 |
+
return SamVisionEncoderOutput(
|
| 1172 |
+
last_hidden_state=hidden_states,
|
| 1173 |
+
hidden_states=all_hidden_states,
|
| 1174 |
+
attentions=all_self_attentions,
|
| 1175 |
+
)
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
class SamPreTrainedModel(PreTrainedModel):
|
| 1179 |
+
config_class = SamConfig
|
| 1180 |
+
base_model_prefix = "sam"
|
| 1181 |
+
main_input_name = "pixel_values"
|
| 1182 |
+
_no_split_modules = ["SamVisionAttention"]
|
| 1183 |
+
supports_gradient_checkpointing = True
|
| 1184 |
+
_supports_sdpa = True
|
| 1185 |
+
|
| 1186 |
+
def _init_weights(self, module):
|
| 1187 |
+
std = self.config.initializer_range
|
| 1188 |
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
| 1189 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 1190 |
+
if module.bias is not None:
|
| 1191 |
+
module.bias.data.zero_()
|
| 1192 |
+
elif isinstance(module, nn.Embedding):
|
| 1193 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 1194 |
+
if module.padding_idx is not None:
|
| 1195 |
+
module.weight.data[module.padding_idx].zero_()
|
| 1196 |
+
elif isinstance(module, (SamLayerNorm, nn.LayerNorm)):
|
| 1197 |
+
module.weight.data.fill_(1.0)
|
| 1198 |
+
module.bias.data.zero_()
|
| 1199 |
+
elif isinstance(module, SamVisionAttention):
|
| 1200 |
+
if module.use_rel_pos:
|
| 1201 |
+
module.rel_pos_h.data.zero_()
|
| 1202 |
+
module.rel_pos_w.data.zero_()
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
SAM_START_DOCSTRING = r"""
|
| 1206 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1207 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1208 |
+
etc.)
|
| 1209 |
+
|
| 1210 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 1211 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 1212 |
+
and behavior.
|
| 1213 |
+
|
| 1214 |
+
Parameters:
|
| 1215 |
+
config ([`SamConfig`]): Model configuration class with all the parameters of the model.
|
| 1216 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1217 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1218 |
+
"""
|
| 1219 |
+
|
| 1220 |
+
|
| 1221 |
+
SAM_INPUTS_DOCSTRING = r"""
|
| 1222 |
+
Args:
|
| 1223 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1224 |
+
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
|
| 1225 |
+
details.
|
| 1226 |
+
input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
|
| 1227 |
+
Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
|
| 1228 |
+
better results. The points can be obtained by passing a list of list of list to the processor that will
|
| 1229 |
+
create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
|
| 1230 |
+
second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
|
| 1231 |
+
per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
|
| 1232 |
+
multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
|
| 1233 |
+
coordinates of the point. If a different number of points is passed either for each image, or for each
|
| 1234 |
+
mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
|
| 1235 |
+
computation of the embedding will be skipped for these points using the labels.
|
| 1236 |
+
input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
|
| 1237 |
+
Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
|
| 1238 |
+
official implementation, there are 3 types of labels
|
| 1239 |
+
|
| 1240 |
+
- `1`: the point is a point that contains the object of interest
|
| 1241 |
+
- `0`: the point is a point that does not contain the object of interest
|
| 1242 |
+
- `-1`: the point corresponds to the background
|
| 1243 |
+
|
| 1244 |
+
We added the label:
|
| 1245 |
+
|
| 1246 |
+
- `-10`: the point is a padding point, thus should be ignored by the prompt encoder
|
| 1247 |
+
|
| 1248 |
+
The padding labels should be automatically done by the processor.
|
| 1249 |
+
input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
|
| 1250 |
+
Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
|
| 1251 |
+
much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
|
| 1252 |
+
that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
|
| 1253 |
+
size, the number of boxes per image and the coordinates of the top left and botton right point of the box.
|
| 1254 |
+
In the order (`x1`, `y1`, `x2`, `y2`):
|
| 1255 |
+
|
| 1256 |
+
- `x1`: the x coordinate of the top left point of the input box
|
| 1257 |
+
- `y1`: the y coordinate of the top left point of the input box
|
| 1258 |
+
- `x2`: the x coordinate of the bottom right point of the input box
|
| 1259 |
+
- `y2`: the y coordinate of the bottom right point of the input box
|
| 1260 |
+
|
| 1261 |
+
input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
|
| 1262 |
+
SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
|
| 1263 |
+
generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
|
| 1264 |
+
manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
|
| 1265 |
+
|
| 1266 |
+
image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
|
| 1267 |
+
Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
|
| 1268 |
+
efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
|
| 1269 |
+
method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
|
| 1270 |
+
multimask_output (`bool`, *optional*):
|
| 1271 |
+
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
|
| 1272 |
+
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
|
| 1273 |
+
"best" mask, by specifying `multimask_output=False`.
|
| 1274 |
+
attention_similarity (`torch.FloatTensor`, *optional*):
|
| 1275 |
+
Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
|
| 1276 |
+
model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
|
| 1277 |
+
target_embedding (`torch.FloatTensor`, *optional*):
|
| 1278 |
+
Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
|
| 1279 |
+
the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
|
| 1280 |
+
output_attentions (`bool`, *optional*):
|
| 1281 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1282 |
+
tensors for more detail.
|
| 1283 |
+
output_hidden_states (`bool`, *optional*):
|
| 1284 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1285 |
+
more detail.
|
| 1286 |
+
return_dict (`bool`, *optional*):
|
| 1287 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1288 |
+
"""
|
| 1289 |
+
|
| 1290 |
+
|
| 1291 |
+
SAM_VISION_INPUTS_DOCSTRING = r"""
|
| 1292 |
+
Args:
|
| 1293 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1294 |
+
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
|
| 1295 |
+
details.
|
| 1296 |
+
output_attentions (`bool`, *optional*):
|
| 1297 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1298 |
+
tensors for more detail.
|
| 1299 |
+
output_hidden_states (`bool`, *optional*):
|
| 1300 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1301 |
+
more detail.
|
| 1302 |
+
return_dict (`bool`, *optional*):
|
| 1303 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1304 |
+
"""
|
| 1305 |
+
|
| 1306 |
+
|
| 1307 |
+
@add_start_docstrings(
|
| 1308 |
+
"""The vision model from Sam without any head or projection on top.""",
|
| 1309 |
+
SAM_START_DOCSTRING,
|
| 1310 |
+
)
|
| 1311 |
+
class SamVisionModel(SamPreTrainedModel):
|
| 1312 |
+
config_class = SamVisionConfig
|
| 1313 |
+
main_input_name = "pixel_values"
|
| 1314 |
+
|
| 1315 |
+
def __init__(self, config: SamVisionConfig):
|
| 1316 |
+
super().__init__(config)
|
| 1317 |
+
self.vision_encoder = SamVisionEncoder(config)
|
| 1318 |
+
|
| 1319 |
+
# Initialize weights and apply final processing
|
| 1320 |
+
self.post_init()
|
| 1321 |
+
|
| 1322 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 1323 |
+
return self.vision_encoder.patch_embed
|
| 1324 |
+
|
| 1325 |
+
@add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
|
| 1326 |
+
@replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig)
|
| 1327 |
+
def forward(
|
| 1328 |
+
self,
|
| 1329 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1330 |
+
output_attentions: Optional[bool] = None,
|
| 1331 |
+
output_hidden_states: Optional[bool] = None,
|
| 1332 |
+
return_dict: Optional[bool] = None,
|
| 1333 |
+
) -> Union[Tuple, SamVisionEncoderOutput]:
|
| 1334 |
+
r"""
|
| 1335 |
+
Returns:
|
| 1336 |
+
|
| 1337 |
+
"""
|
| 1338 |
+
return self.vision_encoder(
|
| 1339 |
+
pixel_values,
|
| 1340 |
+
output_attentions=output_attentions,
|
| 1341 |
+
output_hidden_states=output_hidden_states,
|
| 1342 |
+
return_dict=return_dict,
|
| 1343 |
+
)
|
| 1344 |
+
|
| 1345 |
+
|
| 1346 |
+
@add_start_docstrings(
|
| 1347 |
+
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
|
| 1348 |
+
" optional 2D location and bounding boxes.",
|
| 1349 |
+
SAM_START_DOCSTRING,
|
| 1350 |
+
)
|
| 1351 |
+
class SamModel(SamPreTrainedModel):
|
| 1352 |
+
_tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
|
| 1353 |
+
# need to be ignored, as it's a buffer and will not be correctly detected as tied weight
|
| 1354 |
+
_keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
|
| 1355 |
+
|
| 1356 |
+
def __init__(self, config):
|
| 1357 |
+
super().__init__(config)
|
| 1358 |
+
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
|
| 1359 |
+
|
| 1360 |
+
self.vision_encoder = SamVisionEncoder(config.vision_config)
|
| 1361 |
+
self.prompt_encoder = SamPromptEncoder(config)
|
| 1362 |
+
self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
|
| 1363 |
+
|
| 1364 |
+
self.post_init()
|
| 1365 |
+
|
| 1366 |
+
def _tie_weights(self):
|
| 1367 |
+
self.prompt_encoder.shared_embedding.positional_embedding.data = (
|
| 1368 |
+
self.shared_image_embedding.positional_embedding.data
|
| 1369 |
+
)
|
| 1370 |
+
|
| 1371 |
+
def get_input_embeddings(self):
|
| 1372 |
+
return self.vision_encoder.get_input_embeddings()
|
| 1373 |
+
|
| 1374 |
+
def get_image_wide_positional_embeddings(self):
|
| 1375 |
+
size = self.config.prompt_encoder_config.image_embedding_size
|
| 1376 |
+
target_device = self.shared_image_embedding.positional_embedding.device
|
| 1377 |
+
target_dtype = self.shared_image_embedding.positional_embedding.dtype
|
| 1378 |
+
grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
|
| 1379 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
| 1380 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
| 1381 |
+
y_embed = y_embed / size
|
| 1382 |
+
x_embed = x_embed / size
|
| 1383 |
+
|
| 1384 |
+
positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
|
| 1385 |
+
return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
|
| 1386 |
+
|
| 1387 |
+
@torch.no_grad()
|
| 1388 |
+
def get_image_embeddings(
|
| 1389 |
+
self,
|
| 1390 |
+
pixel_values,
|
| 1391 |
+
output_attentions: Optional[bool] = None,
|
| 1392 |
+
output_hidden_states: Optional[bool] = None,
|
| 1393 |
+
):
|
| 1394 |
+
r"""
|
| 1395 |
+
Returns the image embeddings by passing the pixel values through the vision encoder.
|
| 1396 |
+
|
| 1397 |
+
Args:
|
| 1398 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1399 |
+
Input pixel values
|
| 1400 |
+
output_attentions (`bool`, *optional*):
|
| 1401 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 1402 |
+
output_hidden_states (`bool`, *optional*):
|
| 1403 |
+
Whether or not to return the hidden states of all layers.
|
| 1404 |
+
"""
|
| 1405 |
+
vision_output = self.vision_encoder(
|
| 1406 |
+
pixel_values,
|
| 1407 |
+
output_attentions=output_attentions,
|
| 1408 |
+
output_hidden_states=output_hidden_states,
|
| 1409 |
+
)
|
| 1410 |
+
image_embeddings = vision_output[0]
|
| 1411 |
+
return image_embeddings
|
| 1412 |
+
|
| 1413 |
+
@torch.no_grad()
|
| 1414 |
+
def get_prompt_embeddings(
|
| 1415 |
+
self,
|
| 1416 |
+
input_points: Optional[torch.FloatTensor] = None,
|
| 1417 |
+
input_labels: Optional[torch.LongTensor] = None,
|
| 1418 |
+
input_boxes: Optional[torch.FloatTensor] = None,
|
| 1419 |
+
input_masks: Optional[torch.LongTensor] = None,
|
| 1420 |
+
):
|
| 1421 |
+
r"""
|
| 1422 |
+
Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
|
| 1423 |
+
|
| 1424 |
+
Args:
|
| 1425 |
+
input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
|
| 1426 |
+
Optional input points for the prompt encoder. The padding of the point is automatically done by the
|
| 1427 |
+
processor. `point_batch_size` refers to the number of masks that we want the model to predict per
|
| 1428 |
+
point. The model will output `point_batch_size` times 3 masks in total.
|
| 1429 |
+
input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
|
| 1430 |
+
Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
|
| 1431 |
+
processor, or can be fed by the user.
|
| 1432 |
+
input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
|
| 1433 |
+
Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
|
| 1434 |
+
processor. users can also pass manually the input boxes.
|
| 1435 |
+
input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
|
| 1436 |
+
Optional input masks for the prompt encoder.
|
| 1437 |
+
"""
|
| 1438 |
+
prompt_output = self.prompt_encoder(
|
| 1439 |
+
input_points=input_points,
|
| 1440 |
+
input_labels=input_labels,
|
| 1441 |
+
input_boxes=input_boxes,
|
| 1442 |
+
input_masks=input_masks,
|
| 1443 |
+
)
|
| 1444 |
+
return prompt_output
|
| 1445 |
+
|
| 1446 |
+
@can_return_tuple
|
| 1447 |
+
@add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
|
| 1448 |
+
def forward(
|
| 1449 |
+
self,
|
| 1450 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1451 |
+
input_points: Optional[torch.FloatTensor] = None,
|
| 1452 |
+
input_labels: Optional[torch.LongTensor] = None,
|
| 1453 |
+
input_boxes: Optional[torch.FloatTensor] = None,
|
| 1454 |
+
input_masks: Optional[torch.LongTensor] = None,
|
| 1455 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
| 1456 |
+
multimask_output: bool = True,
|
| 1457 |
+
attention_similarity: Optional[torch.FloatTensor] = None,
|
| 1458 |
+
target_embedding: Optional[torch.FloatTensor] = None,
|
| 1459 |
+
output_attentions: Optional[bool] = None,
|
| 1460 |
+
output_hidden_states: Optional[bool] = None,
|
| 1461 |
+
**kwargs,
|
| 1462 |
+
) -> SamImageSegmentationOutput:
|
| 1463 |
+
r"""
|
| 1464 |
+
Example:
|
| 1465 |
+
|
| 1466 |
+
```python
|
| 1467 |
+
>>> from PIL import Image
|
| 1468 |
+
>>> import requests
|
| 1469 |
+
>>> from transformers import AutoModel, AutoProcessor
|
| 1470 |
+
|
| 1471 |
+
>>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
|
| 1472 |
+
>>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
|
| 1473 |
+
|
| 1474 |
+
>>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
|
| 1475 |
+
>>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
| 1476 |
+
>>> input_points = [[[400, 650]]] # 2D location of a window on the car
|
| 1477 |
+
>>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
|
| 1478 |
+
|
| 1479 |
+
>>> # Get segmentation mask
|
| 1480 |
+
>>> outputs = model(**inputs)
|
| 1481 |
+
|
| 1482 |
+
>>> # Postprocess masks
|
| 1483 |
+
>>> masks = processor.post_process_masks(
|
| 1484 |
+
... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
|
| 1485 |
+
... )
|
| 1486 |
+
```
|
| 1487 |
+
"""
|
| 1488 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1489 |
+
output_hidden_states = (
|
| 1490 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1491 |
+
)
|
| 1492 |
+
|
| 1493 |
+
if pixel_values is None and image_embeddings is None:
|
| 1494 |
+
raise ValueError("Either pixel_values or image_embeddings must be provided.")
|
| 1495 |
+
|
| 1496 |
+
if pixel_values is not None and image_embeddings is not None:
|
| 1497 |
+
raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
|
| 1498 |
+
|
| 1499 |
+
if input_points is not None and len(input_points.shape) != 4:
|
| 1500 |
+
raise ValueError(
|
| 1501 |
+
"The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
|
| 1502 |
+
" got {}.".format(input_points.shape),
|
| 1503 |
+
)
|
| 1504 |
+
if input_boxes is not None and len(input_boxes.shape) != 3:
|
| 1505 |
+
raise ValueError(
|
| 1506 |
+
"The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
|
| 1507 |
+
" got {}.".format(input_boxes.shape),
|
| 1508 |
+
)
|
| 1509 |
+
if input_points is not None and input_boxes is not None:
|
| 1510 |
+
point_batch_size = input_points.shape[1]
|
| 1511 |
+
box_batch_size = input_boxes.shape[1]
|
| 1512 |
+
if point_batch_size != box_batch_size:
|
| 1513 |
+
raise ValueError(
|
| 1514 |
+
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
|
| 1515 |
+
point_batch_size, box_batch_size
|
| 1516 |
+
)
|
| 1517 |
+
)
|
| 1518 |
+
|
| 1519 |
+
image_positional_embeddings = self.get_image_wide_positional_embeddings()
|
| 1520 |
+
# repeat with batch size
|
| 1521 |
+
batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
|
| 1522 |
+
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
|
| 1523 |
+
|
| 1524 |
+
vision_attentions = None
|
| 1525 |
+
vision_hidden_states = None
|
| 1526 |
+
|
| 1527 |
+
if pixel_values is not None:
|
| 1528 |
+
vision_outputs: SamVisionEncoderOutput = self.vision_encoder(
|
| 1529 |
+
pixel_values,
|
| 1530 |
+
output_attentions=output_attentions,
|
| 1531 |
+
output_hidden_states=output_hidden_states,
|
| 1532 |
+
)
|
| 1533 |
+
image_embeddings = vision_outputs.last_hidden_state
|
| 1534 |
+
|
| 1535 |
+
if output_hidden_states:
|
| 1536 |
+
vision_hidden_states = vision_outputs.hidden_states
|
| 1537 |
+
if output_attentions:
|
| 1538 |
+
vision_attentions = vision_outputs.attentions
|
| 1539 |
+
|
| 1540 |
+
if input_points is not None and input_labels is None:
|
| 1541 |
+
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
|
| 1542 |
+
|
| 1543 |
+
if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
|
| 1544 |
+
raise ValueError(
|
| 1545 |
+
"The batch size of the image embeddings and the input points must be the same. ",
|
| 1546 |
+
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
|
| 1547 |
+
" if you want to pass multiple points for the same image, make sure that you passed ",
|
| 1548 |
+
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
|
| 1549 |
+
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
| 1550 |
+
)
|
| 1551 |
+
|
| 1552 |
+
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
| 1553 |
+
input_points=input_points,
|
| 1554 |
+
input_labels=input_labels,
|
| 1555 |
+
input_boxes=input_boxes,
|
| 1556 |
+
input_masks=input_masks,
|
| 1557 |
+
)
|
| 1558 |
+
|
| 1559 |
+
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
|
| 1560 |
+
image_embeddings=image_embeddings,
|
| 1561 |
+
image_positional_embeddings=image_positional_embeddings,
|
| 1562 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 1563 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 1564 |
+
multimask_output=multimask_output,
|
| 1565 |
+
attention_similarity=attention_similarity,
|
| 1566 |
+
target_embedding=target_embedding,
|
| 1567 |
+
output_attentions=output_attentions,
|
| 1568 |
+
)
|
| 1569 |
+
|
| 1570 |
+
return SamImageSegmentationOutput(
|
| 1571 |
+
iou_scores=iou_predictions,
|
| 1572 |
+
pred_masks=low_res_masks,
|
| 1573 |
+
vision_hidden_states=vision_hidden_states,
|
| 1574 |
+
vision_attentions=vision_attentions,
|
| 1575 |
+
mask_decoder_attentions=mask_decoder_attentions,
|
| 1576 |
+
)
|
| 1577 |
+
|
| 1578 |
+
|
| 1579 |
+
__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/sam/modeling_tf_sam.py
ADDED
|
@@ -0,0 +1,1726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
|
| 17 |
+
discrepancy, the original file should be regarded as the 'reference' version.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import collections
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import tensorflow as tf
|
| 28 |
+
|
| 29 |
+
from ...activations_tf import ACT2FN
|
| 30 |
+
from ...modeling_tf_outputs import TFBaseModelOutput
|
| 31 |
+
from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
|
| 32 |
+
from ...tf_utils import flatten, functional_layernorm
|
| 33 |
+
from ...utils import (
|
| 34 |
+
ModelOutput,
|
| 35 |
+
add_start_docstrings,
|
| 36 |
+
add_start_docstrings_to_model_forward,
|
| 37 |
+
logging,
|
| 38 |
+
replace_return_docstrings,
|
| 39 |
+
)
|
| 40 |
+
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
_CONFIG_FOR_DOC = "SamConfig"
|
| 46 |
+
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class TFSamVisionEncoderOutput(ModelOutput):
|
| 51 |
+
"""
|
| 52 |
+
Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
|
| 53 |
+
layer to the pooler_output.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
| 57 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
| 58 |
+
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 59 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 60 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 61 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
|
| 62 |
+
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 63 |
+
|
| 64 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 65 |
+
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 66 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 67 |
+
sequence_length)`.
|
| 68 |
+
|
| 69 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 70 |
+
heads.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
image_embeds: tf.Tensor | None = None
|
| 74 |
+
last_hidden_state: Optional[tf.Tensor] = None
|
| 75 |
+
hidden_states: Tuple[tf.Tensor, ...] | None = None
|
| 76 |
+
attentions: Tuple[tf.Tensor, ...] | None = None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass
|
| 80 |
+
class TFSamImageSegmentationOutput(ModelOutput):
|
| 81 |
+
"""
|
| 82 |
+
Base class for Segment-Anything model's output
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
|
| 86 |
+
The iou scores of the predicted masks.
|
| 87 |
+
pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
|
| 88 |
+
The predicted low resolutions masks. Needs to be post-processed by the processor
|
| 89 |
+
vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 90 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
|
| 91 |
+
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 92 |
+
|
| 93 |
+
Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
|
| 94 |
+
vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 95 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 96 |
+
sequence_length)`.
|
| 97 |
+
|
| 98 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 99 |
+
heads.
|
| 100 |
+
mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 101 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 102 |
+
sequence_length)`.
|
| 103 |
+
|
| 104 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 105 |
+
heads.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
iou_scores: Optional[tf.Tensor] = None
|
| 109 |
+
pred_masks: Optional[tf.Tensor] = None
|
| 110 |
+
vision_hidden_states: Tuple[tf.Tensor, ...] | None = None
|
| 111 |
+
vision_attentions: Tuple[tf.Tensor, ...] | None = None
|
| 112 |
+
mask_decoder_attentions: Tuple[tf.Tensor, ...] | None = None
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class TFSamPatchEmbeddings(keras.layers.Layer):
|
| 116 |
+
"""
|
| 117 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 118 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 119 |
+
Transformer.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def __init__(self, config, **kwargs):
|
| 123 |
+
super().__init__(**kwargs)
|
| 124 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 125 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 126 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 127 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 128 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 129 |
+
self.image_size = image_size
|
| 130 |
+
self.patch_size = patch_size
|
| 131 |
+
self.num_channels = num_channels
|
| 132 |
+
self.num_patches = num_patches
|
| 133 |
+
|
| 134 |
+
self.projection = keras.layers.Conv2D(
|
| 135 |
+
hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def call(self, pixel_values):
|
| 139 |
+
batch_size, num_channels, height, width = shape_list(pixel_values)
|
| 140 |
+
if num_channels != self.num_channels:
|
| 141 |
+
raise ValueError(
|
| 142 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 143 |
+
)
|
| 144 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
| 145 |
+
raise ValueError(
|
| 146 |
+
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
| 147 |
+
)
|
| 148 |
+
embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
|
| 149 |
+
return embeddings
|
| 150 |
+
|
| 151 |
+
def build(self, input_shape=None):
|
| 152 |
+
if self.built:
|
| 153 |
+
return
|
| 154 |
+
self.built = True
|
| 155 |
+
if getattr(self, "projection", None) is not None:
|
| 156 |
+
with tf.name_scope(self.projection.name):
|
| 157 |
+
self.projection.build([None, None, None, self.num_channels])
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class TFSamMLPBlock(keras.layers.Layer):
|
| 161 |
+
def __init__(self, config, **kwargs):
|
| 162 |
+
super().__init__(**kwargs)
|
| 163 |
+
self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
|
| 164 |
+
self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
|
| 165 |
+
self.act = ACT2FN[config.hidden_act]
|
| 166 |
+
self.config = config
|
| 167 |
+
|
| 168 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 169 |
+
hidden_states = self.lin1(hidden_states)
|
| 170 |
+
hidden_states = self.act(hidden_states)
|
| 171 |
+
hidden_states = self.lin2(hidden_states)
|
| 172 |
+
return hidden_states
|
| 173 |
+
|
| 174 |
+
def build(self, input_shape=None):
|
| 175 |
+
if self.built:
|
| 176 |
+
return
|
| 177 |
+
self.built = True
|
| 178 |
+
if getattr(self, "lin1", None) is not None:
|
| 179 |
+
with tf.name_scope(self.lin1.name):
|
| 180 |
+
self.lin1.build([None, None, self.config.hidden_size])
|
| 181 |
+
if getattr(self, "lin2", None) is not None:
|
| 182 |
+
with tf.name_scope(self.lin2.name):
|
| 183 |
+
self.lin2.build([None, None, self.config.mlp_dim])
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class TFSamLayerNorm(keras.layers.Layer):
|
| 187 |
+
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 188 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
|
| 189 |
+
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
|
| 193 |
+
super().__init__(**kwargs)
|
| 194 |
+
self.eps = eps
|
| 195 |
+
self.data_format = data_format
|
| 196 |
+
self.normalized_shape = normalized_shape
|
| 197 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 198 |
+
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
|
| 199 |
+
|
| 200 |
+
def build(self, input_shape):
|
| 201 |
+
self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
|
| 202 |
+
self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
|
| 203 |
+
super().build(input_shape)
|
| 204 |
+
|
| 205 |
+
def call(self, x: tf.Tensor) -> tf.Tensor:
|
| 206 |
+
if self.data_format == "channels_last":
|
| 207 |
+
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
|
| 208 |
+
elif self.data_format == "channels_first":
|
| 209 |
+
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
|
| 210 |
+
return x
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class TFSamAttention(keras.layers.Layer):
|
| 214 |
+
"""
|
| 215 |
+
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
| 216 |
+
values.
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
def __init__(self, config, downsample_rate=None, **kwargs):
|
| 220 |
+
super().__init__(**kwargs)
|
| 221 |
+
self.hidden_size = config.hidden_size
|
| 222 |
+
|
| 223 |
+
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
|
| 224 |
+
|
| 225 |
+
self.internal_dim = config.hidden_size // downsample_rate
|
| 226 |
+
self.num_attention_heads = config.num_attention_heads
|
| 227 |
+
if self.internal_dim % config.num_attention_heads != 0:
|
| 228 |
+
raise ValueError("num_attention_heads must divide hidden_size.")
|
| 229 |
+
|
| 230 |
+
self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
|
| 231 |
+
self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
|
| 232 |
+
self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
|
| 233 |
+
self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")
|
| 234 |
+
|
| 235 |
+
def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
|
| 236 |
+
batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
|
| 237 |
+
c_per_head = channel // num_attention_heads
|
| 238 |
+
hidden_states = tf.reshape(
|
| 239 |
+
hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
|
| 240 |
+
)
|
| 241 |
+
return tf.transpose(hidden_states, perm=[0, 2, 1, 3])
|
| 242 |
+
|
| 243 |
+
def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
|
| 244 |
+
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
|
| 245 |
+
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
|
| 246 |
+
return tf.reshape(
|
| 247 |
+
hidden_states,
|
| 248 |
+
(batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
|
| 252 |
+
# Input projections
|
| 253 |
+
query = self.q_proj(query)
|
| 254 |
+
key = self.k_proj(key)
|
| 255 |
+
value = self.v_proj(value)
|
| 256 |
+
|
| 257 |
+
point_batch_size = shape_list(query)[1]
|
| 258 |
+
# Separate into heads
|
| 259 |
+
query = self._separate_heads(query, self.num_attention_heads)
|
| 260 |
+
key = self._separate_heads(key, self.num_attention_heads)
|
| 261 |
+
value = self._separate_heads(value, self.num_attention_heads)
|
| 262 |
+
|
| 263 |
+
# SamAttention
|
| 264 |
+
_, _, _, c_per_head = shape_list(query)
|
| 265 |
+
attn = tf.matmul(
|
| 266 |
+
query, tf.transpose(key, perm=[0, 1, 3, 2])
|
| 267 |
+
) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
|
| 268 |
+
attn = attn / tf.math.sqrt(float(c_per_head))
|
| 269 |
+
attn = tf.nn.softmax(attn, axis=-1)
|
| 270 |
+
|
| 271 |
+
# Get output
|
| 272 |
+
out = tf.matmul(attn, value)
|
| 273 |
+
out = self._recombine_heads(out, point_batch_size)
|
| 274 |
+
out = self.out_proj(out)
|
| 275 |
+
|
| 276 |
+
return out
|
| 277 |
+
|
| 278 |
+
def build(self, input_shape=None):
|
| 279 |
+
if self.built:
|
| 280 |
+
return
|
| 281 |
+
self.built = True
|
| 282 |
+
if getattr(self, "q_proj", None) is not None:
|
| 283 |
+
with tf.name_scope(self.q_proj.name):
|
| 284 |
+
self.q_proj.build([None, None, self.hidden_size])
|
| 285 |
+
if getattr(self, "k_proj", None) is not None:
|
| 286 |
+
with tf.name_scope(self.k_proj.name):
|
| 287 |
+
self.k_proj.build([None, None, self.hidden_size])
|
| 288 |
+
if getattr(self, "v_proj", None) is not None:
|
| 289 |
+
with tf.name_scope(self.v_proj.name):
|
| 290 |
+
self.v_proj.build([None, None, self.hidden_size])
|
| 291 |
+
if getattr(self, "out_proj", None) is not None:
|
| 292 |
+
with tf.name_scope(self.out_proj.name):
|
| 293 |
+
self.out_proj.build([None, None, self.internal_dim])
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class TFSamTwoWayAttentionBlock(keras.layers.Layer):
|
| 297 |
+
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
|
| 298 |
+
"""
|
| 299 |
+
A transformer block with four layers:
|
| 300 |
+
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
|
| 301 |
+
sparse inputs (4) cross attention of dense inputs -> sparse inputs
|
| 302 |
+
|
| 303 |
+
Arguments:
|
| 304 |
+
config (`SamMaskDecoderConfig`):
|
| 305 |
+
The configuration file used to instantiate the block
|
| 306 |
+
attention_downsample_rate (*optionalk*, int, defaults to 2):
|
| 307 |
+
The downsample ratio of the block used to reduce the inner dim of the attention.
|
| 308 |
+
skip_first_layer_pe (*optional*, bool, defaults to `False`):
|
| 309 |
+
Whether or not to skip the addition of the query_point_embedding on the first layer.
|
| 310 |
+
"""
|
| 311 |
+
super().__init__(**kwargs)
|
| 312 |
+
|
| 313 |
+
self.hidden_size = config.hidden_size
|
| 314 |
+
self.layer_norm_eps = config.layer_norm_eps
|
| 315 |
+
|
| 316 |
+
self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
|
| 317 |
+
self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")
|
| 318 |
+
|
| 319 |
+
self.cross_attn_token_to_image = TFSamAttention(
|
| 320 |
+
config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
|
| 321 |
+
)
|
| 322 |
+
self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")
|
| 323 |
+
|
| 324 |
+
self.mlp = TFSamMLPBlock(config, name="mlp")
|
| 325 |
+
self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")
|
| 326 |
+
|
| 327 |
+
self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
|
| 328 |
+
self.cross_attn_image_to_token = TFSamAttention(
|
| 329 |
+
config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
self.skip_first_layer_pe = skip_first_layer_pe
|
| 333 |
+
|
| 334 |
+
def call(
|
| 335 |
+
self,
|
| 336 |
+
queries: tf.Tensor,
|
| 337 |
+
keys: tf.Tensor,
|
| 338 |
+
query_point_embedding: tf.Tensor,
|
| 339 |
+
key_point_embedding: tf.Tensor,
|
| 340 |
+
output_attentions: bool = False,
|
| 341 |
+
):
|
| 342 |
+
# Self attention block
|
| 343 |
+
if self.skip_first_layer_pe:
|
| 344 |
+
queries = self.self_attn(query=queries, key=queries, value=queries)
|
| 345 |
+
else:
|
| 346 |
+
query = queries + query_point_embedding
|
| 347 |
+
attn_out = self.self_attn(query=query, key=query, value=queries)
|
| 348 |
+
queries = queries + attn_out
|
| 349 |
+
queries = self.layer_norm1(queries)
|
| 350 |
+
|
| 351 |
+
# Cross attention block, tokens attending to image embedding
|
| 352 |
+
query = queries + query_point_embedding
|
| 353 |
+
key = keys + key_point_embedding
|
| 354 |
+
|
| 355 |
+
attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
|
| 356 |
+
queries = queries + attn_out
|
| 357 |
+
|
| 358 |
+
queries = self.layer_norm2(queries)
|
| 359 |
+
|
| 360 |
+
# MLP block
|
| 361 |
+
mlp_out = self.mlp(queries)
|
| 362 |
+
queries = queries + mlp_out
|
| 363 |
+
queries = self.layer_norm3(queries)
|
| 364 |
+
|
| 365 |
+
# Cross attention block, image embedding attending to tokens
|
| 366 |
+
query = queries + query_point_embedding
|
| 367 |
+
key = keys + key_point_embedding
|
| 368 |
+
|
| 369 |
+
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
|
| 370 |
+
keys = keys + attn_out
|
| 371 |
+
|
| 372 |
+
keys = self.layer_norm4(keys)
|
| 373 |
+
|
| 374 |
+
outputs = (queries, keys)
|
| 375 |
+
|
| 376 |
+
if output_attentions:
|
| 377 |
+
outputs = outputs + (attn_out,)
|
| 378 |
+
else:
|
| 379 |
+
outputs = outputs + (None,)
|
| 380 |
+
|
| 381 |
+
return outputs
|
| 382 |
+
|
| 383 |
+
def build(self, input_shape=None):
|
| 384 |
+
if self.built:
|
| 385 |
+
return
|
| 386 |
+
self.built = True
|
| 387 |
+
if getattr(self, "self_attn", None) is not None:
|
| 388 |
+
with tf.name_scope(self.self_attn.name):
|
| 389 |
+
self.self_attn.build(None)
|
| 390 |
+
if getattr(self, "layer_norm1", None) is not None:
|
| 391 |
+
with tf.name_scope(self.layer_norm1.name):
|
| 392 |
+
self.layer_norm1.build([None, None, None, self.hidden_size])
|
| 393 |
+
if getattr(self, "cross_attn_token_to_image", None) is not None:
|
| 394 |
+
with tf.name_scope(self.cross_attn_token_to_image.name):
|
| 395 |
+
self.cross_attn_token_to_image.build(None)
|
| 396 |
+
if getattr(self, "layer_norm2", None) is not None:
|
| 397 |
+
with tf.name_scope(self.layer_norm2.name):
|
| 398 |
+
self.layer_norm2.build([None, None, None, self.hidden_size])
|
| 399 |
+
if getattr(self, "mlp", None) is not None:
|
| 400 |
+
with tf.name_scope(self.mlp.name):
|
| 401 |
+
self.mlp.build(None)
|
| 402 |
+
if getattr(self, "layer_norm3", None) is not None:
|
| 403 |
+
with tf.name_scope(self.layer_norm3.name):
|
| 404 |
+
self.layer_norm3.build([None, None, None, self.hidden_size])
|
| 405 |
+
if getattr(self, "layer_norm4", None) is not None:
|
| 406 |
+
with tf.name_scope(self.layer_norm4.name):
|
| 407 |
+
self.layer_norm4.build([None, None, None, self.hidden_size])
|
| 408 |
+
if getattr(self, "cross_attn_image_to_token", None) is not None:
|
| 409 |
+
with tf.name_scope(self.cross_attn_image_to_token.name):
|
| 410 |
+
self.cross_attn_image_to_token.build(None)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class TFSamTwoWayTransformer(keras.layers.Layer):
|
| 414 |
+
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
|
| 415 |
+
super().__init__(**kwargs)
|
| 416 |
+
self.config = config
|
| 417 |
+
|
| 418 |
+
self.num_hidden_layers = config.num_hidden_layers
|
| 419 |
+
self.layers = []
|
| 420 |
+
|
| 421 |
+
for i in range(self.num_hidden_layers):
|
| 422 |
+
self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))
|
| 423 |
+
|
| 424 |
+
self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
|
| 425 |
+
self.layer_norm_final_attn = keras.layers.LayerNormalization(
|
| 426 |
+
epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
def call(
|
| 430 |
+
self,
|
| 431 |
+
point_embeddings: tf.Tensor,
|
| 432 |
+
image_embeddings: tf.Tensor,
|
| 433 |
+
image_positional_embeddings: tf.Tensor,
|
| 434 |
+
output_attentions: Optional[bool] = None,
|
| 435 |
+
output_hidden_states: Optional[bool] = None,
|
| 436 |
+
return_dict: Optional[bool] = None,
|
| 437 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 438 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 439 |
+
output_hidden_states = (
|
| 440 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 441 |
+
)
|
| 442 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 443 |
+
|
| 444 |
+
all_attentions = ()
|
| 445 |
+
|
| 446 |
+
if image_embeddings is None:
|
| 447 |
+
raise ValueError("You have to specify an image_embedding")
|
| 448 |
+
|
| 449 |
+
image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
|
| 450 |
+
image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]
|
| 451 |
+
|
| 452 |
+
# Prepare queries
|
| 453 |
+
queries = point_embeddings
|
| 454 |
+
keys = image_embeddings
|
| 455 |
+
|
| 456 |
+
# Apply transformer blocks and final layernorm
|
| 457 |
+
for layer in self.layers:
|
| 458 |
+
queries, keys, attention_outputs = layer(
|
| 459 |
+
queries=queries,
|
| 460 |
+
keys=keys,
|
| 461 |
+
query_point_embedding=point_embeddings,
|
| 462 |
+
key_point_embedding=image_positional_embeddings,
|
| 463 |
+
output_attentions=output_attentions,
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
if output_attentions:
|
| 467 |
+
all_attentions = all_attentions + (attention_outputs,)
|
| 468 |
+
|
| 469 |
+
# Apply the final attenion layer from the points to the image
|
| 470 |
+
query = queries + point_embeddings
|
| 471 |
+
key = keys + image_positional_embeddings
|
| 472 |
+
|
| 473 |
+
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
|
| 474 |
+
|
| 475 |
+
queries = queries + attn_out
|
| 476 |
+
queries = self.layer_norm_final_attn(queries)
|
| 477 |
+
return queries, keys, all_attentions
|
| 478 |
+
|
| 479 |
+
def build(self, input_shape=None):
|
| 480 |
+
if self.built:
|
| 481 |
+
return
|
| 482 |
+
self.built = True
|
| 483 |
+
if getattr(self, "final_attn_token_to_image", None) is not None:
|
| 484 |
+
with tf.name_scope(self.final_attn_token_to_image.name):
|
| 485 |
+
self.final_attn_token_to_image.build(None)
|
| 486 |
+
if getattr(self, "layer_norm_final_attn", None) is not None:
|
| 487 |
+
with tf.name_scope(self.layer_norm_final_attn.name):
|
| 488 |
+
self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
|
| 489 |
+
for layer in self.layers:
|
| 490 |
+
with tf.name_scope(layer.name):
|
| 491 |
+
layer.build(None)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class TFSamFeedForward(keras.layers.Layer):
|
| 495 |
+
def __init__(
|
| 496 |
+
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
|
| 497 |
+
):
|
| 498 |
+
super().__init__(**kwargs)
|
| 499 |
+
self.num_layers = num_layers
|
| 500 |
+
self.activation = keras.layers.ReLU()
|
| 501 |
+
self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
|
| 502 |
+
self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
|
| 503 |
+
self.layers = [
|
| 504 |
+
keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
|
| 505 |
+
for i in range(num_layers - 2)
|
| 506 |
+
]
|
| 507 |
+
self.sigmoid_output = sigmoid_output
|
| 508 |
+
self.hidden_dim = hidden_dim
|
| 509 |
+
self.input_dim = input_dim
|
| 510 |
+
|
| 511 |
+
def call(self, hidden_states):
|
| 512 |
+
hidden_states = self.proj_in(hidden_states)
|
| 513 |
+
hidden_states = self.activation(hidden_states)
|
| 514 |
+
for layer in self.layers:
|
| 515 |
+
hidden_states = self.activation(layer(hidden_states))
|
| 516 |
+
|
| 517 |
+
hidden_states = self.proj_out(hidden_states)
|
| 518 |
+
if self.sigmoid_output:
|
| 519 |
+
hidden_states = tf.sigmoid(hidden_states)
|
| 520 |
+
return hidden_states
|
| 521 |
+
|
| 522 |
+
def build(self, input_shape=None):
|
| 523 |
+
if self.built:
|
| 524 |
+
return
|
| 525 |
+
self.built = True
|
| 526 |
+
if getattr(self, "proj_in", None) is not None:
|
| 527 |
+
with tf.name_scope(self.proj_in.name):
|
| 528 |
+
self.proj_in.build([None, None, self.input_dim])
|
| 529 |
+
if getattr(self, "proj_out", None) is not None:
|
| 530 |
+
with tf.name_scope(self.proj_out.name):
|
| 531 |
+
self.proj_out.build([None, None, self.hidden_dim])
|
| 532 |
+
if getattr(self, "layers", None) is not None:
|
| 533 |
+
for layer in self.layers:
|
| 534 |
+
with tf.name_scope(layer.name):
|
| 535 |
+
layer.build([None, None, self.hidden_dim])
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class TFSamMaskDecoder(keras.layers.Layer):
|
| 539 |
+
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
|
| 540 |
+
super().__init__(**kwargs)
|
| 541 |
+
|
| 542 |
+
self.hidden_size = config.hidden_size
|
| 543 |
+
|
| 544 |
+
self.num_multimask_outputs = config.num_multimask_outputs
|
| 545 |
+
self.num_mask_tokens = config.num_multimask_outputs + 1
|
| 546 |
+
|
| 547 |
+
self.transformer = TFSamTwoWayTransformer(config, name="transformer")
|
| 548 |
+
|
| 549 |
+
self.upscale_conv1 = keras.layers.Conv2DTranspose(
|
| 550 |
+
self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
|
| 551 |
+
)
|
| 552 |
+
self.upscale_conv2 = keras.layers.Conv2DTranspose(
|
| 553 |
+
self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
|
| 554 |
+
)
|
| 555 |
+
self.upscale_layer_norm = TFSamLayerNorm(
|
| 556 |
+
self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
|
| 557 |
+
)
|
| 558 |
+
self.activation = tf.nn.gelu
|
| 559 |
+
|
| 560 |
+
mlps_list = []
|
| 561 |
+
for i in range(self.num_mask_tokens):
|
| 562 |
+
mlps_list += [
|
| 563 |
+
TFSamFeedForward(
|
| 564 |
+
self.hidden_size,
|
| 565 |
+
self.hidden_size,
|
| 566 |
+
self.hidden_size // 8,
|
| 567 |
+
3,
|
| 568 |
+
name=f"output_hypernetworks_mlps_._{i}",
|
| 569 |
+
)
|
| 570 |
+
]
|
| 571 |
+
self.output_hypernetworks_mlps = mlps_list
|
| 572 |
+
|
| 573 |
+
self.iou_prediction_head = TFSamFeedForward(
|
| 574 |
+
self.hidden_size,
|
| 575 |
+
config.iou_head_hidden_dim,
|
| 576 |
+
self.num_mask_tokens,
|
| 577 |
+
config.iou_head_depth,
|
| 578 |
+
name="iou_prediction_head",
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
def build(self, input_shape=None):
|
| 582 |
+
if self.built:
|
| 583 |
+
return
|
| 584 |
+
self.built = True
|
| 585 |
+
self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
|
| 586 |
+
self.mask_tokens = self.add_weight(
|
| 587 |
+
shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
if getattr(self, "transformer", None) is not None:
|
| 591 |
+
with tf.name_scope(self.transformer.name):
|
| 592 |
+
self.transformer.build(None)
|
| 593 |
+
if getattr(self, "upscale_conv1", None) is not None:
|
| 594 |
+
with tf.name_scope(self.upscale_conv1.name):
|
| 595 |
+
self.upscale_conv1.build([None, self.hidden_size, None, None])
|
| 596 |
+
if getattr(self, "upscale_conv2", None) is not None:
|
| 597 |
+
with tf.name_scope(self.upscale_conv2.name):
|
| 598 |
+
self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
|
| 599 |
+
if getattr(self, "upscale_layer_norm", None) is not None:
|
| 600 |
+
with tf.name_scope(self.upscale_layer_norm.name):
|
| 601 |
+
self.upscale_layer_norm.build(None)
|
| 602 |
+
if getattr(self, "iou_prediction_head", None) is not None:
|
| 603 |
+
with tf.name_scope(self.iou_prediction_head.name):
|
| 604 |
+
self.iou_prediction_head.build(None)
|
| 605 |
+
for mlp in self.output_hypernetworks_mlps:
|
| 606 |
+
with tf.name_scope(mlp.name):
|
| 607 |
+
mlp.build(None)
|
| 608 |
+
|
| 609 |
+
def call(
|
| 610 |
+
self,
|
| 611 |
+
image_embeddings: tf.Tensor,
|
| 612 |
+
image_positional_embeddings: tf.Tensor,
|
| 613 |
+
sparse_prompt_embeddings: tf.Tensor,
|
| 614 |
+
dense_prompt_embeddings: tf.Tensor,
|
| 615 |
+
multimask_output: bool,
|
| 616 |
+
output_attentions: Optional[bool] = None,
|
| 617 |
+
) -> Tuple[tf.Tensor, tf.Tensor]:
|
| 618 |
+
batch_size, num_channels, height, width = shape_list(image_embeddings)
|
| 619 |
+
point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])
|
| 620 |
+
|
| 621 |
+
output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32)
|
| 622 |
+
output_tokens = tf.tile(
|
| 623 |
+
output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
|
| 624 |
+
) # Should be (batch_size, point_size, 5, 32)
|
| 625 |
+
|
| 626 |
+
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
|
| 627 |
+
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
|
| 628 |
+
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
|
| 629 |
+
if shape_list(sparse_prompt_embeddings)[1] != 0:
|
| 630 |
+
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
|
| 631 |
+
else:
|
| 632 |
+
tokens = output_tokens
|
| 633 |
+
point_embeddings = tf.cast(tokens, self.iou_token.dtype)
|
| 634 |
+
|
| 635 |
+
image_embeddings = image_embeddings + dense_prompt_embeddings
|
| 636 |
+
image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
|
| 637 |
+
image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
|
| 638 |
+
|
| 639 |
+
point_embedding, image_embeddings, attentions = self.transformer(
|
| 640 |
+
point_embeddings=point_embeddings,
|
| 641 |
+
image_embeddings=image_embeddings,
|
| 642 |
+
image_positional_embeddings=image_positional_embeddings,
|
| 643 |
+
output_attentions=output_attentions,
|
| 644 |
+
)
|
| 645 |
+
iou_token_out = point_embedding[:, :, 0, :]
|
| 646 |
+
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
|
| 647 |
+
|
| 648 |
+
image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
|
| 649 |
+
image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])
|
| 650 |
+
|
| 651 |
+
upscaled_embedding = self.upscale_conv1(image_embeddings)
|
| 652 |
+
upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
|
| 653 |
+
upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
|
| 654 |
+
|
| 655 |
+
hyper_in_list = []
|
| 656 |
+
for i in range(self.num_mask_tokens):
|
| 657 |
+
current_mlp = self.output_hypernetworks_mlps[i]
|
| 658 |
+
hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
|
| 659 |
+
hyper_in = tf.stack(hyper_in_list, axis=2)
|
| 660 |
+
|
| 661 |
+
_, num_channels, height, width = shape_list(upscaled_embedding)
|
| 662 |
+
upscaled_embedding = tf.reshape(
|
| 663 |
+
upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
|
| 664 |
+
)
|
| 665 |
+
masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])
|
| 666 |
+
|
| 667 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
| 668 |
+
|
| 669 |
+
if multimask_output:
|
| 670 |
+
mask_slice = slice(1, None)
|
| 671 |
+
else:
|
| 672 |
+
mask_slice = slice(0, 1)
|
| 673 |
+
masks = masks[:, :, mask_slice, :, :]
|
| 674 |
+
iou_pred = iou_pred[:, :, mask_slice]
|
| 675 |
+
|
| 676 |
+
outputs = (masks, iou_pred)
|
| 677 |
+
|
| 678 |
+
if output_attentions:
|
| 679 |
+
outputs = outputs + (attentions,)
|
| 680 |
+
else:
|
| 681 |
+
outputs = outputs + (None,)
|
| 682 |
+
|
| 683 |
+
return outputs
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
class TFSamPositionalEmbedding(keras.layers.Layer):
|
| 687 |
+
def __init__(self, config, **kwargs):
|
| 688 |
+
super().__init__(**kwargs)
|
| 689 |
+
self.scale = config.hidden_size // 2
|
| 690 |
+
self.config = config
|
| 691 |
+
|
| 692 |
+
def build(self, input_shape):
|
| 693 |
+
# TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
|
| 694 |
+
self.positional_embedding = self.add_weight(
|
| 695 |
+
name="positional_embedding",
|
| 696 |
+
shape=(2, self.config.num_pos_feats),
|
| 697 |
+
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
|
| 698 |
+
trainable=False,
|
| 699 |
+
)
|
| 700 |
+
super().build(input_shape)
|
| 701 |
+
|
| 702 |
+
def call(self, input_coords, input_shape=None):
|
| 703 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 704 |
+
coordinates = tf.identity(input_coords)
|
| 705 |
+
|
| 706 |
+
if input_shape is not None:
|
| 707 |
+
coordinates = tf.stack(
|
| 708 |
+
[
|
| 709 |
+
tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
|
| 710 |
+
tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
|
| 711 |
+
],
|
| 712 |
+
axis=-1,
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 716 |
+
coordinates = 2 * coordinates - 1
|
| 717 |
+
coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
|
| 718 |
+
coordinates = tf.matmul(coordinates, self.positional_embedding)
|
| 719 |
+
coordinates = 2 * np.pi * coordinates
|
| 720 |
+
# outputs d_1 x ... x d_n x channel shape
|
| 721 |
+
return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
class TFSamMaskEmbedding(keras.layers.Layer):
|
| 725 |
+
def __init__(self, config: SamPromptEncoderConfig, **kwargs):
|
| 726 |
+
super().__init__(**kwargs)
|
| 727 |
+
self.mask_input_channels = config.mask_input_channels // 4
|
| 728 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 729 |
+
self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
|
| 730 |
+
self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
|
| 731 |
+
self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
|
| 732 |
+
self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
|
| 733 |
+
self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
|
| 734 |
+
self.config = config
|
| 735 |
+
|
| 736 |
+
def call(self, masks):
|
| 737 |
+
masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last
|
| 738 |
+
hidden_states = self.conv1(masks)
|
| 739 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 740 |
+
hidden_states = self.activation(hidden_states)
|
| 741 |
+
|
| 742 |
+
hidden_states = self.conv2(hidden_states)
|
| 743 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 744 |
+
hidden_states = self.activation(hidden_states)
|
| 745 |
+
dense_embeddings = self.conv3(hidden_states)
|
| 746 |
+
dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first
|
| 747 |
+
return dense_embeddings
|
| 748 |
+
|
| 749 |
+
def build(self, input_shape=None):
|
| 750 |
+
# This class needs an explicit build method because it isn't called with the standard dummy inputs
|
| 751 |
+
if self.built:
|
| 752 |
+
return
|
| 753 |
+
self.built = True
|
| 754 |
+
with tf.name_scope("conv1"):
|
| 755 |
+
self.conv1.build([None, None, None, 1])
|
| 756 |
+
with tf.name_scope("conv2"):
|
| 757 |
+
self.conv2.build([None, None, None, self.mask_input_channels])
|
| 758 |
+
with tf.name_scope("conv3"):
|
| 759 |
+
self.conv3.build([None, None, None, self.mask_input_channels * 4])
|
| 760 |
+
with tf.name_scope("layer_norm1"):
|
| 761 |
+
self.layer_norm1.build([None, None, None, self.mask_input_channels])
|
| 762 |
+
with tf.name_scope("layer_norm2"):
|
| 763 |
+
self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
class TFSamPromptEncoder(keras.layers.Layer):
|
| 767 |
+
def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
|
| 768 |
+
super().__init__(**kwargs)
|
| 769 |
+
self.shared_embedding = shared_patch_embedding
|
| 770 |
+
self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
|
| 771 |
+
self.no_mask_embed = None
|
| 772 |
+
|
| 773 |
+
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
|
| 774 |
+
self.input_image_size = config.image_size
|
| 775 |
+
|
| 776 |
+
self.point_embed = []
|
| 777 |
+
self.hidden_size = config.hidden_size
|
| 778 |
+
self.not_a_point_embed = None
|
| 779 |
+
self.config = config
|
| 780 |
+
|
| 781 |
+
def build(self, input_shape=None):
|
| 782 |
+
self.no_mask_embed = self.add_weight(
|
| 783 |
+
name="no_mask_embed.weight",
|
| 784 |
+
shape=(1, self.hidden_size),
|
| 785 |
+
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
|
| 786 |
+
trainable=True,
|
| 787 |
+
)
|
| 788 |
+
self.point_embed = [
|
| 789 |
+
self.add_weight(
|
| 790 |
+
name=f"point_embed_._{i}.weight",
|
| 791 |
+
shape=(1, self.hidden_size),
|
| 792 |
+
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
|
| 793 |
+
trainable=True,
|
| 794 |
+
)
|
| 795 |
+
for i in range(self.config.num_point_embeddings)
|
| 796 |
+
]
|
| 797 |
+
self.not_a_point_embed = self.add_weight(
|
| 798 |
+
name="not_a_point_embed.weight",
|
| 799 |
+
shape=(1, self.hidden_size),
|
| 800 |
+
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
|
| 801 |
+
trainable=True,
|
| 802 |
+
)
|
| 803 |
+
with tf.name_scope("mask_embed"):
|
| 804 |
+
# We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
|
| 805 |
+
self.mask_embed.build(
|
| 806 |
+
(None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
if self.built:
|
| 810 |
+
return
|
| 811 |
+
self.built = True
|
| 812 |
+
if getattr(self, "mask_embed", None) is not None:
|
| 813 |
+
with tf.name_scope(self.mask_embed.name):
|
| 814 |
+
self.mask_embed.build(None)
|
| 815 |
+
|
| 816 |
+
def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
|
| 817 |
+
"""Embeds point prompts."""
|
| 818 |
+
points = points + 0.5 # Shift to center of pixel
|
| 819 |
+
if pad:
|
| 820 |
+
target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
|
| 821 |
+
target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
|
| 822 |
+
padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
|
| 823 |
+
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
|
| 824 |
+
points = tf.concat([points, padding_point], axis=2)
|
| 825 |
+
labels = tf.concat([labels, padding_label], axis=2)
|
| 826 |
+
input_shape = (self.input_image_size, self.input_image_size)
|
| 827 |
+
point_embedding = self.shared_embedding(points, input_shape)
|
| 828 |
+
|
| 829 |
+
point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
|
| 830 |
+
|
| 831 |
+
point_embedding = tf.where(
|
| 832 |
+
labels[..., None] != -10,
|
| 833 |
+
point_embedding,
|
| 834 |
+
tf.zeros_like(point_embedding),
|
| 835 |
+
)
|
| 836 |
+
point_embedding = tf.where(
|
| 837 |
+
(labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
|
| 838 |
+
)
|
| 839 |
+
point_embedding = tf.where(
|
| 840 |
+
(labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
|
| 841 |
+
)
|
| 842 |
+
return point_embedding
|
| 843 |
+
|
| 844 |
+
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
|
| 845 |
+
"""Embeds box prompts."""
|
| 846 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
| 847 |
+
batch_size, nb_boxes = shape_list(boxes)[:2]
|
| 848 |
+
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
|
| 849 |
+
input_shape = (self.input_image_size, self.input_image_size)
|
| 850 |
+
corner_embedding = self.shared_embedding(coords, input_shape)
|
| 851 |
+
corner_embedding += tf.where(
|
| 852 |
+
tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
|
| 853 |
+
self.point_embed[2][0],
|
| 854 |
+
self.point_embed[3][0],
|
| 855 |
+
)
|
| 856 |
+
return corner_embedding
|
| 857 |
+
|
| 858 |
+
def call(
|
| 859 |
+
self,
|
| 860 |
+
batch_size: Optional[int],
|
| 861 |
+
input_points: Optional[Tuple[tf.Tensor, tf.Tensor]],
|
| 862 |
+
input_labels: tf.Tensor | None,
|
| 863 |
+
input_boxes: tf.Tensor | None,
|
| 864 |
+
input_masks: tf.Tensor | None,
|
| 865 |
+
) -> Tuple[tf.Tensor, tf.Tensor]:
|
| 866 |
+
"""
|
| 867 |
+
Embeds different types of prompts, returning both sparse and dense embeddings.
|
| 868 |
+
|
| 869 |
+
Args:
|
| 870 |
+
points (`tf.Tensor`, *optional*):
|
| 871 |
+
point coordinates and labels to embed.
|
| 872 |
+
boxes (`tf.Tensor`, *optional*):
|
| 873 |
+
boxes to embed
|
| 874 |
+
masks (`tf.Tensor`, *optional*):
|
| 875 |
+
masks to embed
|
| 876 |
+
"""
|
| 877 |
+
sparse_embeddings = None
|
| 878 |
+
if input_points is not None:
|
| 879 |
+
batch_size, point_batch_size = shape_list(input_points)[:2]
|
| 880 |
+
if input_labels is None:
|
| 881 |
+
raise ValueError("If points are provided, labels must also be provided.")
|
| 882 |
+
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
| 883 |
+
sparse_embeddings = tf.zeros(
|
| 884 |
+
(batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
|
| 885 |
+
)
|
| 886 |
+
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
|
| 887 |
+
if input_boxes is not None:
|
| 888 |
+
batch_size = shape_list(input_boxes)[0]
|
| 889 |
+
box_embeddings = self._embed_boxes(input_boxes)
|
| 890 |
+
if sparse_embeddings is None:
|
| 891 |
+
sparse_embeddings = box_embeddings
|
| 892 |
+
else:
|
| 893 |
+
sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
|
| 894 |
+
if input_masks is not None:
|
| 895 |
+
dense_embeddings = self.mask_embed(input_masks)
|
| 896 |
+
else:
|
| 897 |
+
dense_embeddings = self.no_mask_embed[0]
|
| 898 |
+
dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
|
| 899 |
+
dense_embeddings = tf.tile(
|
| 900 |
+
dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
|
| 901 |
+
)
|
| 902 |
+
if sparse_embeddings is None:
|
| 903 |
+
sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
|
| 904 |
+
|
| 905 |
+
return sparse_embeddings, dense_embeddings
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
class TFSamVisionAttention(keras.layers.Layer):
|
| 909 |
+
"""Multi-head Attention block with relative position embeddings."""
|
| 910 |
+
|
| 911 |
+
def __init__(self, config, window_size, **kwargs):
|
| 912 |
+
super().__init__(**kwargs)
|
| 913 |
+
input_size = (
|
| 914 |
+
(config.image_size // config.patch_size, config.image_size // config.patch_size)
|
| 915 |
+
if window_size == 0
|
| 916 |
+
else (window_size, window_size)
|
| 917 |
+
)
|
| 918 |
+
self.input_size = input_size
|
| 919 |
+
|
| 920 |
+
self.num_attention_heads = config.num_attention_heads
|
| 921 |
+
head_dim = config.hidden_size // config.num_attention_heads
|
| 922 |
+
self.head_dim = head_dim
|
| 923 |
+
self.scale = head_dim**-0.5
|
| 924 |
+
self.dropout = config.attention_dropout
|
| 925 |
+
|
| 926 |
+
self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
|
| 927 |
+
self.proj = keras.layers.Dense(config.hidden_size, name="proj")
|
| 928 |
+
|
| 929 |
+
self.use_rel_pos = config.use_rel_pos
|
| 930 |
+
if self.use_rel_pos:
|
| 931 |
+
if input_size is None:
|
| 932 |
+
raise ValueError("Input size must be provided if using relative positional encoding.")
|
| 933 |
+
self.config = config
|
| 934 |
+
|
| 935 |
+
def build(self, input_shape=None):
|
| 936 |
+
if self.input_size is not None:
|
| 937 |
+
# initialize relative positional embeddings
|
| 938 |
+
self.rel_pos_h = self.add_weight(
|
| 939 |
+
shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
|
| 940 |
+
)
|
| 941 |
+
self.rel_pos_w = self.add_weight(
|
| 942 |
+
shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
if self.built:
|
| 946 |
+
return
|
| 947 |
+
self.built = True
|
| 948 |
+
if getattr(self, "qkv", None) is not None:
|
| 949 |
+
with tf.name_scope(self.qkv.name):
|
| 950 |
+
self.qkv.build([None, None, self.config.hidden_size])
|
| 951 |
+
if getattr(self, "proj", None) is not None:
|
| 952 |
+
with tf.name_scope(self.proj.name):
|
| 953 |
+
self.proj.build([None, None, self.config.hidden_size])
|
| 954 |
+
|
| 955 |
+
def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
|
| 956 |
+
"""
|
| 957 |
+
Get relative positional embeddings according to the relative positions of
|
| 958 |
+
query and key sizes.
|
| 959 |
+
|
| 960 |
+
Args:
|
| 961 |
+
q_size (int):
|
| 962 |
+
size of the query.
|
| 963 |
+
k_size (int):
|
| 964 |
+
size of key k.
|
| 965 |
+
rel_pos (`tf.Tensor`):
|
| 966 |
+
relative position embeddings (L, channel).
|
| 967 |
+
|
| 968 |
+
Returns:
|
| 969 |
+
Extracted positional embeddings according to relative positions.
|
| 970 |
+
"""
|
| 971 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
| 972 |
+
# Interpolate rel pos if needed.
|
| 973 |
+
if rel_pos.shape[0] != max_rel_dist:
|
| 974 |
+
# Interpolate rel pos.
|
| 975 |
+
rel_pos_resized = tf.image.resize(
|
| 976 |
+
tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
|
| 977 |
+
size=(max_rel_dist, rel_pos.shape[1]),
|
| 978 |
+
method="bilinear",
|
| 979 |
+
)
|
| 980 |
+
rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
|
| 981 |
+
else:
|
| 982 |
+
rel_pos_resized = rel_pos
|
| 983 |
+
|
| 984 |
+
# Scale the coords with short length if shapes for q and k are different.
|
| 985 |
+
q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
|
| 986 |
+
k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
|
| 987 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
| 988 |
+
|
| 989 |
+
return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
|
| 990 |
+
|
| 991 |
+
def get_decomposed_rel_pos(
|
| 992 |
+
self,
|
| 993 |
+
query: tf.Tensor,
|
| 994 |
+
rel_pos_h: tf.Tensor,
|
| 995 |
+
rel_pos_w: tf.Tensor,
|
| 996 |
+
q_size: Tuple[int, int],
|
| 997 |
+
k_size: Tuple[int, int],
|
| 998 |
+
) -> tf.Tensor:
|
| 999 |
+
"""
|
| 1000 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
| 1001 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
| 1002 |
+
|
| 1003 |
+
Args:
|
| 1004 |
+
query (`tf.Tensor`):
|
| 1005 |
+
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
|
| 1006 |
+
rel_pos_h (`tf.Tensor`):
|
| 1007 |
+
relative position embeddings (Lh, channel) for height axis.
|
| 1008 |
+
rel_pos_w (`tf.Tensor`):
|
| 1009 |
+
relative position embeddings (Lw, channel) for width axis.
|
| 1010 |
+
q_size (tuple):
|
| 1011 |
+
spatial sequence size of query q with (query_height, query_width).
|
| 1012 |
+
k_size (tuple):
|
| 1013 |
+
spatial sequence size of key k with (key_height, key_width).
|
| 1014 |
+
|
| 1015 |
+
Returns:
|
| 1016 |
+
decomposed_rel_pos (`torch.Tensor`):
|
| 1017 |
+
decomposed relative position embeddings.
|
| 1018 |
+
"""
|
| 1019 |
+
query_height, query_width = q_size
|
| 1020 |
+
key_height, key_width = k_size
|
| 1021 |
+
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
|
| 1022 |
+
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
|
| 1023 |
+
|
| 1024 |
+
batch_size, _, dim = shape_list(query)
|
| 1025 |
+
reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
|
| 1026 |
+
rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
|
| 1027 |
+
rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
|
| 1028 |
+
|
| 1029 |
+
rel_h = tf.expand_dims(rel_h, axis=-1)
|
| 1030 |
+
rel_w = tf.expand_dims(rel_w, axis=-2)
|
| 1031 |
+
decomposed_rel_pos = rel_h + rel_w
|
| 1032 |
+
|
| 1033 |
+
return decomposed_rel_pos
|
| 1034 |
+
|
| 1035 |
+
def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
|
| 1036 |
+
batch_size, height, width, _ = shape_list(hidden_states)
|
| 1037 |
+
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
| 1038 |
+
qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
|
| 1039 |
+
qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
|
| 1040 |
+
# q, k, v with shape (batch_size * nHead, height * width, channel)
|
| 1041 |
+
query, key, value = tf.unstack(
|
| 1042 |
+
tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
|
| 1043 |
+
)
|
| 1044 |
+
attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
|
| 1045 |
+
|
| 1046 |
+
if self.use_rel_pos:
|
| 1047 |
+
decomposed_rel_pos = self.get_decomposed_rel_pos(
|
| 1048 |
+
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
| 1049 |
+
)
|
| 1050 |
+
decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
|
| 1051 |
+
attn_weights = attn_weights + decomposed_rel_pos
|
| 1052 |
+
|
| 1053 |
+
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
| 1054 |
+
|
| 1055 |
+
if training:
|
| 1056 |
+
attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
|
| 1057 |
+
else:
|
| 1058 |
+
attn_probs = attn_weights
|
| 1059 |
+
|
| 1060 |
+
attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
|
| 1061 |
+
attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
|
| 1062 |
+
attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))
|
| 1063 |
+
|
| 1064 |
+
attn_output = self.proj(attn_output)
|
| 1065 |
+
|
| 1066 |
+
if output_attentions:
|
| 1067 |
+
outputs = (attn_output, attn_weights)
|
| 1068 |
+
else:
|
| 1069 |
+
outputs = (attn_output, None)
|
| 1070 |
+
|
| 1071 |
+
return outputs
|
| 1072 |
+
|
| 1073 |
+
|
| 1074 |
+
class TFSamVisionLayer(keras.layers.Layer):
|
| 1075 |
+
def __init__(self, config, window_size, **kwargs):
|
| 1076 |
+
super().__init__(**kwargs)
|
| 1077 |
+
self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
|
| 1078 |
+
self.attn = TFSamVisionAttention(config, window_size, name="attn")
|
| 1079 |
+
self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
|
| 1080 |
+
self.mlp = TFSamMLPBlock(config, name="mlp")
|
| 1081 |
+
self.window_size = window_size
|
| 1082 |
+
self.config = config
|
| 1083 |
+
|
| 1084 |
+
def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
|
| 1085 |
+
batch_size, height, width, channel = shape_list(hidden_states)
|
| 1086 |
+
|
| 1087 |
+
pad_h = (window_size - height % window_size) % window_size
|
| 1088 |
+
pad_w = (window_size - width % window_size) % window_size
|
| 1089 |
+
if pad_h > 0 or pad_w > 0:
|
| 1090 |
+
hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
|
| 1091 |
+
pad_height, pad_width = height + pad_h, width + pad_w
|
| 1092 |
+
|
| 1093 |
+
hidden_states = tf.reshape(
|
| 1094 |
+
hidden_states,
|
| 1095 |
+
[batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
|
| 1096 |
+
)
|
| 1097 |
+
windows = tf.reshape(
|
| 1098 |
+
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
|
| 1099 |
+
)
|
| 1100 |
+
return windows, (pad_height, pad_width)
|
| 1101 |
+
|
| 1102 |
+
def window_unpartition(
|
| 1103 |
+
self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
|
| 1104 |
+
) -> tf.Tensor:
|
| 1105 |
+
pad_height, pad_width = padding_shape
|
| 1106 |
+
height, width = original_shape
|
| 1107 |
+
batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
|
| 1108 |
+
hidden_states = tf.reshape(
|
| 1109 |
+
windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
|
| 1110 |
+
)
|
| 1111 |
+
hidden_states = tf.reshape(
|
| 1112 |
+
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
if pad_height > height or pad_width > width:
|
| 1116 |
+
hidden_states = hidden_states[:, :height, :width, :]
|
| 1117 |
+
return hidden_states
|
| 1118 |
+
|
| 1119 |
+
def call(
|
| 1120 |
+
self,
|
| 1121 |
+
hidden_states: tf.Tensor,
|
| 1122 |
+
output_attentions: Optional[bool] = False,
|
| 1123 |
+
training: Optional[bool] = False,
|
| 1124 |
+
) -> Tuple[tf.Tensor]:
|
| 1125 |
+
residual = hidden_states
|
| 1126 |
+
|
| 1127 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 1128 |
+
if self.window_size > 0:
|
| 1129 |
+
height, width = hidden_states.shape[1], hidden_states.shape[2]
|
| 1130 |
+
hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
|
| 1131 |
+
|
| 1132 |
+
hidden_states, attn_weights = self.attn(
|
| 1133 |
+
hidden_states=hidden_states,
|
| 1134 |
+
output_attentions=output_attentions,
|
| 1135 |
+
training=training,
|
| 1136 |
+
)
|
| 1137 |
+
if self.window_size > 0:
|
| 1138 |
+
hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
|
| 1139 |
+
|
| 1140 |
+
hidden_states = residual + hidden_states
|
| 1141 |
+
layernorm_output = self.layer_norm2(hidden_states)
|
| 1142 |
+
hidden_states = hidden_states + self.mlp(layernorm_output)
|
| 1143 |
+
|
| 1144 |
+
outputs = (hidden_states,)
|
| 1145 |
+
if output_attentions:
|
| 1146 |
+
outputs += (attn_weights,)
|
| 1147 |
+
|
| 1148 |
+
return outputs
|
| 1149 |
+
|
| 1150 |
+
def build(self, input_shape=None):
|
| 1151 |
+
if self.built:
|
| 1152 |
+
return
|
| 1153 |
+
self.built = True
|
| 1154 |
+
if getattr(self, "layer_norm1", None) is not None:
|
| 1155 |
+
with tf.name_scope(self.layer_norm1.name):
|
| 1156 |
+
self.layer_norm1.build([None, None, None, self.config.hidden_size])
|
| 1157 |
+
if getattr(self, "attn", None) is not None:
|
| 1158 |
+
with tf.name_scope(self.attn.name):
|
| 1159 |
+
self.attn.build(None)
|
| 1160 |
+
if getattr(self, "layer_norm2", None) is not None:
|
| 1161 |
+
with tf.name_scope(self.layer_norm2.name):
|
| 1162 |
+
self.layer_norm2.build([None, None, None, self.config.hidden_size])
|
| 1163 |
+
if getattr(self, "mlp", None) is not None:
|
| 1164 |
+
with tf.name_scope(self.mlp.name):
|
| 1165 |
+
self.mlp.build(None)
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
class TFSamVisionNeck(keras.layers.Layer):
|
| 1169 |
+
def __init__(self, config: SamVisionConfig, **kwargs):
|
| 1170 |
+
super().__init__(**kwargs)
|
| 1171 |
+
self.config = config
|
| 1172 |
+
|
| 1173 |
+
self.conv1 = keras.layers.Conv2D(
|
| 1174 |
+
config.output_channels,
|
| 1175 |
+
kernel_size=1,
|
| 1176 |
+
use_bias=False,
|
| 1177 |
+
name="conv1",
|
| 1178 |
+
)
|
| 1179 |
+
self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1")
|
| 1180 |
+
self.conv2 = keras.layers.Conv2D(
|
| 1181 |
+
config.output_channels,
|
| 1182 |
+
kernel_size=3,
|
| 1183 |
+
padding="same",
|
| 1184 |
+
use_bias=False,
|
| 1185 |
+
name="conv2",
|
| 1186 |
+
)
|
| 1187 |
+
self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2")
|
| 1188 |
+
|
| 1189 |
+
def call(self, hidden_states):
|
| 1190 |
+
hidden_states = self.conv1(hidden_states)
|
| 1191 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 1192 |
+
|
| 1193 |
+
hidden_states = self.conv2(hidden_states)
|
| 1194 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 1195 |
+
hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
|
| 1196 |
+
return hidden_states
|
| 1197 |
+
|
| 1198 |
+
def build(self, input_shape=None):
|
| 1199 |
+
if self.built:
|
| 1200 |
+
return
|
| 1201 |
+
self.built = True
|
| 1202 |
+
if getattr(self, "conv1", None) is not None:
|
| 1203 |
+
with tf.name_scope(self.conv1.name):
|
| 1204 |
+
self.conv1.build([None, None, None, self.config.hidden_size])
|
| 1205 |
+
if getattr(self, "layer_norm1", None) is not None:
|
| 1206 |
+
with tf.name_scope(self.layer_norm1.name):
|
| 1207 |
+
self.layer_norm1.build(None)
|
| 1208 |
+
if getattr(self, "conv2", None) is not None:
|
| 1209 |
+
with tf.name_scope(self.conv2.name):
|
| 1210 |
+
self.conv2.build([None, None, None, self.config.output_channels])
|
| 1211 |
+
if getattr(self, "layer_norm2", None) is not None:
|
| 1212 |
+
with tf.name_scope(self.layer_norm2.name):
|
| 1213 |
+
self.layer_norm2.build(None)
|
| 1214 |
+
|
| 1215 |
+
|
| 1216 |
+
class TFSamVisionEncoder(keras.layers.Layer):
|
| 1217 |
+
def __init__(self, config: SamVisionConfig, **kwargs):
|
| 1218 |
+
super().__init__(**kwargs)
|
| 1219 |
+
self.config = config
|
| 1220 |
+
self.image_size = config.image_size
|
| 1221 |
+
|
| 1222 |
+
self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")
|
| 1223 |
+
|
| 1224 |
+
self.pos_embed = None
|
| 1225 |
+
|
| 1226 |
+
self.layers = []
|
| 1227 |
+
for i in range(config.num_hidden_layers):
|
| 1228 |
+
layer = TFSamVisionLayer(
|
| 1229 |
+
config,
|
| 1230 |
+
window_size=config.window_size if i not in config.global_attn_indexes else 0,
|
| 1231 |
+
name=f"layers_._{i}",
|
| 1232 |
+
)
|
| 1233 |
+
self.layers.append(layer)
|
| 1234 |
+
|
| 1235 |
+
self.neck = TFSamVisionNeck(config, name="neck")
|
| 1236 |
+
|
| 1237 |
+
def build(self, input_shape=None):
|
| 1238 |
+
if self.built:
|
| 1239 |
+
return
|
| 1240 |
+
self.built = True
|
| 1241 |
+
if self.config.use_abs_pos:
|
| 1242 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 1243 |
+
self.pos_embed = self.add_weight(
|
| 1244 |
+
shape=[
|
| 1245 |
+
1,
|
| 1246 |
+
self.config.image_size // self.config.patch_size,
|
| 1247 |
+
self.config.image_size // self.config.patch_size,
|
| 1248 |
+
self.config.hidden_size,
|
| 1249 |
+
],
|
| 1250 |
+
initializer="zeros",
|
| 1251 |
+
trainable=True,
|
| 1252 |
+
name="pos_embed",
|
| 1253 |
+
)
|
| 1254 |
+
|
| 1255 |
+
if getattr(self, "patch_embed", None) is not None:
|
| 1256 |
+
with tf.name_scope(self.patch_embed.name):
|
| 1257 |
+
self.patch_embed.build(None)
|
| 1258 |
+
if getattr(self, "neck", None) is not None:
|
| 1259 |
+
with tf.name_scope(self.neck.name):
|
| 1260 |
+
self.neck.build(None)
|
| 1261 |
+
for layer in self.layers:
|
| 1262 |
+
with tf.name_scope(layer.name):
|
| 1263 |
+
layer.build(None)
|
| 1264 |
+
|
| 1265 |
+
def get_input_embeddings(self):
|
| 1266 |
+
return self.patch_embed
|
| 1267 |
+
|
| 1268 |
+
def call(
|
| 1269 |
+
self,
|
| 1270 |
+
pixel_values: tf.Tensor | None = None,
|
| 1271 |
+
output_attentions: Optional[bool] = None,
|
| 1272 |
+
output_hidden_states: Optional[bool] = None,
|
| 1273 |
+
return_dict: Optional[bool] = None,
|
| 1274 |
+
training: Optional[bool] = False,
|
| 1275 |
+
) -> Union[Tuple, TFSamVisionEncoderOutput]:
|
| 1276 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1277 |
+
output_hidden_states = (
|
| 1278 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1279 |
+
)
|
| 1280 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1281 |
+
|
| 1282 |
+
if pixel_values is None:
|
| 1283 |
+
raise ValueError("You have to specify pixel_values")
|
| 1284 |
+
|
| 1285 |
+
hidden_states = self.patch_embed(pixel_values)
|
| 1286 |
+
if self.pos_embed is not None:
|
| 1287 |
+
hidden_states = hidden_states + self.pos_embed
|
| 1288 |
+
|
| 1289 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1290 |
+
all_self_attentions = () if output_attentions else None
|
| 1291 |
+
|
| 1292 |
+
for i, layer_module in enumerate(self.layers):
|
| 1293 |
+
if output_hidden_states:
|
| 1294 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1295 |
+
|
| 1296 |
+
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
|
| 1297 |
+
|
| 1298 |
+
hidden_states = layer_outputs[0]
|
| 1299 |
+
|
| 1300 |
+
if output_attentions:
|
| 1301 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 1302 |
+
|
| 1303 |
+
if output_hidden_states:
|
| 1304 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1305 |
+
|
| 1306 |
+
hidden_states = self.neck(hidden_states)
|
| 1307 |
+
|
| 1308 |
+
if not return_dict:
|
| 1309 |
+
outputs = (hidden_states,)
|
| 1310 |
+
if output_hidden_states:
|
| 1311 |
+
outputs = outputs + (all_hidden_states,)
|
| 1312 |
+
if output_attentions:
|
| 1313 |
+
outputs = outputs + (all_self_attentions,)
|
| 1314 |
+
return outputs
|
| 1315 |
+
|
| 1316 |
+
return TFSamVisionEncoderOutput(
|
| 1317 |
+
last_hidden_state=hidden_states,
|
| 1318 |
+
hidden_states=all_hidden_states,
|
| 1319 |
+
attentions=all_self_attentions,
|
| 1320 |
+
)
|
| 1321 |
+
|
| 1322 |
+
|
| 1323 |
+
class TFSamPreTrainedModel(TFPreTrainedModel):
|
| 1324 |
+
config_class = SamConfig
|
| 1325 |
+
base_model_prefix = "sam"
|
| 1326 |
+
main_input_name = "pixel_values"
|
| 1327 |
+
|
| 1328 |
+
|
| 1329 |
+
SAM_START_DOCSTRING = r"""
|
| 1330 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1331 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1332 |
+
etc.)
|
| 1333 |
+
|
| 1334 |
+
This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
|
| 1335 |
+
subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to
|
| 1336 |
+
general usage and behavior.
|
| 1337 |
+
|
| 1338 |
+
Parameters:
|
| 1339 |
+
config ([`SamConfig`]): Model configuration class with all the parameters of the model.
|
| 1340 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1341 |
+
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1342 |
+
"""
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
SAM_INPUTS_DOCSTRING = r"""
|
| 1346 |
+
Args:
|
| 1347 |
+
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1348 |
+
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
|
| 1349 |
+
details.
|
| 1350 |
+
input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
|
| 1351 |
+
Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
|
| 1352 |
+
better results. The points can be obtained by passing a list of list of list to the processor that will
|
| 1353 |
+
create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second
|
| 1354 |
+
dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per
|
| 1355 |
+
input point), the third dimension is the number of points per segmentation mask (it is possible to pass
|
| 1356 |
+
multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
|
| 1357 |
+
coordinates of the point. If a different number of points is passed either for each image, or for each
|
| 1358 |
+
mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
|
| 1359 |
+
computation of the embedding will be skipped for these points using the labels.
|
| 1360 |
+
input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
|
| 1361 |
+
Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
|
| 1362 |
+
official implementation, there are 3 types of labels
|
| 1363 |
+
|
| 1364 |
+
- `1`: the point is a point that contains the object of interest
|
| 1365 |
+
- `0`: the point is a point that does not contain the object of interest
|
| 1366 |
+
- `-1`: the point corresponds to the background
|
| 1367 |
+
|
| 1368 |
+
We added the label:
|
| 1369 |
+
|
| 1370 |
+
- `-10`: the point is a padding point, thus should be ignored by the prompt encoder
|
| 1371 |
+
|
| 1372 |
+
The padding labels should be automatically done by the processor.
|
| 1373 |
+
input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
|
| 1374 |
+
Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
|
| 1375 |
+
much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
|
| 1376 |
+
that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size,
|
| 1377 |
+
the number of boxes per image and the coordinates of the top left and botton right point of the box. In the
|
| 1378 |
+
order (`x1`, `y1`, `x2`, `y2`):
|
| 1379 |
+
|
| 1380 |
+
- `x1`: the x coordinate of the top left point of the input box
|
| 1381 |
+
- `y1`: the y coordinate of the top left point of the input box
|
| 1382 |
+
- `x2`: the x coordinate of the bottom right point of the input box
|
| 1383 |
+
- `y2`: the y coordinate of the bottom right point of the input box
|
| 1384 |
+
|
| 1385 |
+
input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
|
| 1386 |
+
SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
|
| 1387 |
+
generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
|
| 1388 |
+
manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
|
| 1389 |
+
|
| 1390 |
+
image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
|
| 1391 |
+
Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
|
| 1392 |
+
efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
|
| 1393 |
+
method, and then feed them to the `call` method instead of feeding the `pixel_values`.
|
| 1394 |
+
multimask_output (`bool`, *optional*):
|
| 1395 |
+
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
|
| 1396 |
+
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
|
| 1397 |
+
"best" mask, by specifying `multimask_output=False`.
|
| 1398 |
+
output_attentions (`bool`, *optional*):
|
| 1399 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1400 |
+
tensors for more detail.
|
| 1401 |
+
output_hidden_states (`bool`, *optional*):
|
| 1402 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1403 |
+
more detail.
|
| 1404 |
+
return_dict (`bool`, *optional*):
|
| 1405 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1406 |
+
"""
|
| 1407 |
+
|
| 1408 |
+
|
| 1409 |
+
SAM_VISION_INPUTS_DOCSTRING = r"""
|
| 1410 |
+
Args:
|
| 1411 |
+
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1412 |
+
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
|
| 1413 |
+
details.
|
| 1414 |
+
output_attentions (`bool`, *optional*):
|
| 1415 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1416 |
+
tensors for more detail.
|
| 1417 |
+
output_hidden_states (`bool`, *optional*):
|
| 1418 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1419 |
+
more detail.
|
| 1420 |
+
return_dict (`bool`, *optional*):
|
| 1421 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1422 |
+
"""
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
@add_start_docstrings(
|
| 1426 |
+
"""The vision model from Sam without any head or projection on top.""",
|
| 1427 |
+
SAM_START_DOCSTRING,
|
| 1428 |
+
)
|
| 1429 |
+
class TFSamVisionModel(TFSamPreTrainedModel):
|
| 1430 |
+
config_class = SamVisionConfig
|
| 1431 |
+
main_input_name = "pixel_values"
|
| 1432 |
+
|
| 1433 |
+
def __init__(self, config: SamVisionConfig, **kwargs):
|
| 1434 |
+
super().__init__(config, **kwargs)
|
| 1435 |
+
self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")
|
| 1436 |
+
|
| 1437 |
+
def build(self, input_shape=None):
|
| 1438 |
+
if self.built:
|
| 1439 |
+
return
|
| 1440 |
+
self.built = True
|
| 1441 |
+
if getattr(self, "vision_encoder", None) is not None:
|
| 1442 |
+
with tf.name_scope(self.vision_encoder.name):
|
| 1443 |
+
self.vision_encoder.build(None)
|
| 1444 |
+
|
| 1445 |
+
def get_input_embeddings(self):
|
| 1446 |
+
return self.vision_encoder.patch_embed
|
| 1447 |
+
|
| 1448 |
+
@unpack_inputs
|
| 1449 |
+
@add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
|
| 1450 |
+
@replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
|
| 1451 |
+
def call(
|
| 1452 |
+
self,
|
| 1453 |
+
pixel_values: TFModelInputType | None = None,
|
| 1454 |
+
output_attentions: bool | None = None,
|
| 1455 |
+
output_hidden_states: bool | None = None,
|
| 1456 |
+
return_dict: bool | None = None,
|
| 1457 |
+
training: bool = False,
|
| 1458 |
+
**kwargs,
|
| 1459 |
+
) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
|
| 1460 |
+
r"""
|
| 1461 |
+
Returns:
|
| 1462 |
+
|
| 1463 |
+
"""
|
| 1464 |
+
return self.vision_encoder(
|
| 1465 |
+
pixel_values,
|
| 1466 |
+
output_attentions=output_attentions,
|
| 1467 |
+
output_hidden_states=output_hidden_states,
|
| 1468 |
+
return_dict=return_dict,
|
| 1469 |
+
training=training,
|
| 1470 |
+
)
|
| 1471 |
+
|
| 1472 |
+
|
| 1473 |
+
@add_start_docstrings(
|
| 1474 |
+
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
|
| 1475 |
+
" optional 2D location and bounding boxes.",
|
| 1476 |
+
SAM_START_DOCSTRING,
|
| 1477 |
+
)
|
| 1478 |
+
class TFSamModel(TFSamPreTrainedModel):
|
| 1479 |
+
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
|
| 1480 |
+
|
| 1481 |
+
def __init__(self, config, **kwargs):
|
| 1482 |
+
super().__init__(config, **kwargs)
|
| 1483 |
+
self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")
|
| 1484 |
+
|
| 1485 |
+
self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
|
| 1486 |
+
self.prompt_encoder = TFSamPromptEncoder(
|
| 1487 |
+
config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
|
| 1488 |
+
)
|
| 1489 |
+
self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
|
| 1490 |
+
self.config = config
|
| 1491 |
+
|
| 1492 |
+
def get_input_embeddings(self):
|
| 1493 |
+
return self.vision_encoder.get_input_embeddings()
|
| 1494 |
+
|
| 1495 |
+
def get_image_wide_positional_embeddings(self):
|
| 1496 |
+
size = self.config.prompt_encoder_config.image_embedding_size
|
| 1497 |
+
grid = tf.ones((size, size))
|
| 1498 |
+
y_embed = tf.math.cumsum(grid, axis=0) - 0.5
|
| 1499 |
+
x_embed = tf.math.cumsum(grid, axis=1) - 0.5
|
| 1500 |
+
y_embed = y_embed / size
|
| 1501 |
+
x_embed = x_embed / size
|
| 1502 |
+
|
| 1503 |
+
positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
|
| 1504 |
+
return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width
|
| 1505 |
+
|
| 1506 |
+
def get_image_embeddings(
|
| 1507 |
+
self,
|
| 1508 |
+
pixel_values,
|
| 1509 |
+
output_attentions: Optional[bool] = None,
|
| 1510 |
+
output_hidden_states: Optional[bool] = None,
|
| 1511 |
+
return_dict: Optional[bool] = None,
|
| 1512 |
+
):
|
| 1513 |
+
r"""
|
| 1514 |
+
Returns the image embeddings by passing the pixel values through the vision encoder.
|
| 1515 |
+
|
| 1516 |
+
Args:
|
| 1517 |
+
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1518 |
+
Input pixel values
|
| 1519 |
+
output_attentions (`bool`, *optional*):
|
| 1520 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 1521 |
+
output_hidden_states (`bool`, *optional*):
|
| 1522 |
+
Whether or not to return the hidden states of all layers.
|
| 1523 |
+
return_dict (`bool`, *optional*):
|
| 1524 |
+
Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple.
|
| 1525 |
+
|
| 1526 |
+
"""
|
| 1527 |
+
vision_output = self.vision_encoder(
|
| 1528 |
+
pixel_values,
|
| 1529 |
+
output_attentions=output_attentions,
|
| 1530 |
+
output_hidden_states=output_hidden_states,
|
| 1531 |
+
return_dict=return_dict,
|
| 1532 |
+
)
|
| 1533 |
+
image_embeddings = vision_output[0]
|
| 1534 |
+
return image_embeddings
|
| 1535 |
+
|
| 1536 |
+
def get_prompt_embeddings(
|
| 1537 |
+
self,
|
| 1538 |
+
input_points: tf.Tensor | None = None,
|
| 1539 |
+
input_labels: tf.Tensor | None = None,
|
| 1540 |
+
input_boxes: tf.Tensor | None = None,
|
| 1541 |
+
input_masks: tf.Tensor | None = None,
|
| 1542 |
+
):
|
| 1543 |
+
r"""
|
| 1544 |
+
Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
|
| 1545 |
+
|
| 1546 |
+
Args:
|
| 1547 |
+
input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
|
| 1548 |
+
Optional input points for the prompt encoder. The padding of the point is automatically done by the
|
| 1549 |
+
processor. `point_batch_size` refers to the number of masks that we want the model to predict per
|
| 1550 |
+
point. The model will output `point_batch_size` times 3 masks in total.
|
| 1551 |
+
input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
|
| 1552 |
+
Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
|
| 1553 |
+
processor, or can be fed by the user.
|
| 1554 |
+
input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
|
| 1555 |
+
Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
|
| 1556 |
+
processor. users can also pass manually the input boxes.
|
| 1557 |
+
input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
|
| 1558 |
+
Optional input masks for the prompt encoder.
|
| 1559 |
+
"""
|
| 1560 |
+
prompt_output = self.prompt_encoder(
|
| 1561 |
+
input_points=input_points,
|
| 1562 |
+
input_labels=input_labels,
|
| 1563 |
+
input_boxes=input_boxes,
|
| 1564 |
+
input_masks=input_masks,
|
| 1565 |
+
)
|
| 1566 |
+
return prompt_output
|
| 1567 |
+
|
| 1568 |
+
@unpack_inputs
|
| 1569 |
+
@add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
|
| 1570 |
+
def call(
|
| 1571 |
+
self,
|
| 1572 |
+
pixel_values: TFModelInputType | None = None,
|
| 1573 |
+
input_points: tf.Tensor | None = None,
|
| 1574 |
+
input_labels: tf.Tensor | None = None,
|
| 1575 |
+
input_boxes: tf.Tensor | None = None,
|
| 1576 |
+
input_masks: tf.Tensor | None = None,
|
| 1577 |
+
image_embeddings: tf.Tensor | None = None,
|
| 1578 |
+
multimask_output: bool = True,
|
| 1579 |
+
output_attentions: bool | None = None,
|
| 1580 |
+
output_hidden_states: bool | None = None,
|
| 1581 |
+
return_dict: bool | None = None,
|
| 1582 |
+
training: bool = False,
|
| 1583 |
+
**kwargs,
|
| 1584 |
+
) -> TFSamImageSegmentationOutput | Tuple[tf.Tensor]:
|
| 1585 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1586 |
+
output_hidden_states = (
|
| 1587 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1588 |
+
)
|
| 1589 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1590 |
+
|
| 1591 |
+
if pixel_values is None and image_embeddings is None:
|
| 1592 |
+
raise ValueError("Either pixel_values or image_embeddings must be provided.")
|
| 1593 |
+
|
| 1594 |
+
if pixel_values is not None and image_embeddings is not None:
|
| 1595 |
+
raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
|
| 1596 |
+
|
| 1597 |
+
if input_points is not None and len(input_points.shape) != 4:
|
| 1598 |
+
raise ValueError(
|
| 1599 |
+
"The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
|
| 1600 |
+
" got {}.".format(input_points.shape),
|
| 1601 |
+
)
|
| 1602 |
+
if input_boxes is not None and len(input_boxes.shape) != 3:
|
| 1603 |
+
raise ValueError(
|
| 1604 |
+
"The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
|
| 1605 |
+
" got {}.".format(input_boxes.shape),
|
| 1606 |
+
)
|
| 1607 |
+
if input_points is not None and input_boxes is not None:
|
| 1608 |
+
point_batch_size = shape_list(input_points)[1]
|
| 1609 |
+
box_batch_size = shape_list(input_boxes)[1]
|
| 1610 |
+
if point_batch_size != box_batch_size:
|
| 1611 |
+
raise ValueError(
|
| 1612 |
+
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
|
| 1613 |
+
point_batch_size, box_batch_size
|
| 1614 |
+
)
|
| 1615 |
+
)
|
| 1616 |
+
if pixel_values is not None:
|
| 1617 |
+
# Ensures that later checks pass even with an all-None shape from the serving signature
|
| 1618 |
+
pixel_values = tf.ensure_shape(
|
| 1619 |
+
pixel_values,
|
| 1620 |
+
[
|
| 1621 |
+
None,
|
| 1622 |
+
self.config.vision_config.num_channels,
|
| 1623 |
+
self.config.vision_config.image_size,
|
| 1624 |
+
self.config.vision_config.image_size,
|
| 1625 |
+
],
|
| 1626 |
+
)
|
| 1627 |
+
image_positional_embeddings = self.get_image_wide_positional_embeddings()
|
| 1628 |
+
# repeat with batch size
|
| 1629 |
+
batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
|
| 1630 |
+
image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)
|
| 1631 |
+
|
| 1632 |
+
vision_attentions = None
|
| 1633 |
+
vision_hidden_states = None
|
| 1634 |
+
|
| 1635 |
+
if pixel_values is not None:
|
| 1636 |
+
vision_outputs = self.vision_encoder(
|
| 1637 |
+
pixel_values,
|
| 1638 |
+
output_attentions=output_attentions,
|
| 1639 |
+
output_hidden_states=output_hidden_states,
|
| 1640 |
+
return_dict=True,
|
| 1641 |
+
training=training,
|
| 1642 |
+
)
|
| 1643 |
+
image_embeddings = vision_outputs["last_hidden_state"]
|
| 1644 |
+
|
| 1645 |
+
if output_hidden_states:
|
| 1646 |
+
vision_hidden_states = vision_outputs["hidden_states"]
|
| 1647 |
+
if output_attentions:
|
| 1648 |
+
vision_attentions = vision_outputs["attentions"]
|
| 1649 |
+
|
| 1650 |
+
if input_points is not None and input_labels is None:
|
| 1651 |
+
input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)
|
| 1652 |
+
|
| 1653 |
+
if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
|
| 1654 |
+
raise ValueError(
|
| 1655 |
+
"The batch size of the image embeddings and the input points must be the same. ",
|
| 1656 |
+
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
|
| 1657 |
+
" if you want to pass multiple points for the same image, make sure that you passed ",
|
| 1658 |
+
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
|
| 1659 |
+
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
| 1660 |
+
)
|
| 1661 |
+
|
| 1662 |
+
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
| 1663 |
+
batch_size=shape_list(image_embeddings)[0],
|
| 1664 |
+
input_points=input_points,
|
| 1665 |
+
input_labels=input_labels,
|
| 1666 |
+
input_boxes=input_boxes,
|
| 1667 |
+
input_masks=input_masks,
|
| 1668 |
+
)
|
| 1669 |
+
|
| 1670 |
+
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
|
| 1671 |
+
image_embeddings=image_embeddings,
|
| 1672 |
+
image_positional_embeddings=image_positional_embeddings,
|
| 1673 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 1674 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 1675 |
+
multimask_output=multimask_output,
|
| 1676 |
+
output_attentions=output_attentions,
|
| 1677 |
+
)
|
| 1678 |
+
|
| 1679 |
+
if not return_dict:
|
| 1680 |
+
output = (iou_predictions, low_res_masks)
|
| 1681 |
+
if output_hidden_states:
|
| 1682 |
+
output = output + (vision_hidden_states,)
|
| 1683 |
+
|
| 1684 |
+
if output_attentions:
|
| 1685 |
+
output = output + (vision_attentions, mask_decoder_attentions)
|
| 1686 |
+
return output
|
| 1687 |
+
|
| 1688 |
+
return TFSamImageSegmentationOutput(
|
| 1689 |
+
iou_scores=iou_predictions,
|
| 1690 |
+
pred_masks=low_res_masks,
|
| 1691 |
+
vision_hidden_states=vision_hidden_states,
|
| 1692 |
+
vision_attentions=vision_attentions,
|
| 1693 |
+
mask_decoder_attentions=mask_decoder_attentions,
|
| 1694 |
+
)
|
| 1695 |
+
|
| 1696 |
+
def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
|
| 1697 |
+
hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
|
| 1698 |
+
attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None
|
| 1699 |
+
|
| 1700 |
+
return TFSamImageSegmentationOutput(
|
| 1701 |
+
iou_scores=output.iou_scores,
|
| 1702 |
+
pred_masks=output.pred_masks,
|
| 1703 |
+
vision_hidden_states=hs if self.config.output_hidden_states else None,
|
| 1704 |
+
vision_attentions=attns if self.config.output_attentions else None,
|
| 1705 |
+
mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
|
| 1706 |
+
)
|
| 1707 |
+
|
| 1708 |
+
def build(self, input_shape=None):
|
| 1709 |
+
if self.built:
|
| 1710 |
+
return
|
| 1711 |
+
self.built = True
|
| 1712 |
+
if getattr(self, "shared_image_embedding", None) is not None:
|
| 1713 |
+
with tf.name_scope(self.shared_image_embedding.name):
|
| 1714 |
+
self.shared_image_embedding.build(None)
|
| 1715 |
+
if getattr(self, "vision_encoder", None) is not None:
|
| 1716 |
+
with tf.name_scope(self.vision_encoder.name):
|
| 1717 |
+
self.vision_encoder.build(None)
|
| 1718 |
+
if getattr(self, "prompt_encoder", None) is not None:
|
| 1719 |
+
with tf.name_scope(self.prompt_encoder.name):
|
| 1720 |
+
self.prompt_encoder.build(None)
|
| 1721 |
+
if getattr(self, "mask_decoder", None) is not None:
|
| 1722 |
+
with tf.name_scope(self.mask_decoder.name):
|
| 1723 |
+
self.mask_decoder.build(None)
|
| 1724 |
+
|
| 1725 |
+
|
| 1726 |
+
__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/seamless_m4t/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_seamless_m4t import *
|
| 22 |
+
from .feature_extraction_seamless_m4t import *
|
| 23 |
+
from .modeling_seamless_m4t import *
|
| 24 |
+
from .processing_seamless_m4t import *
|
| 25 |
+
from .tokenization_seamless_m4t import *
|
| 26 |
+
from .tokenization_seamless_m4t_fast import *
|
| 27 |
+
else:
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
_file = globals()["__file__"]
|
| 31 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/seamless_m4t/configuration_seamless_m4t.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SeamlessM4T model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SeamlessM4TConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`~SeamlessM4TModel`]. It is used to instantiate an
|
| 27 |
+
SeamlessM4T model according to the specified arguments, defining the model architecture. Instantiating a
|
| 28 |
+
configuration with the defaults will yield a similar configuration to that of the SeamlessM4T
|
| 29 |
+
["facebook/hf-seamless-m4t-medium"](https://huggingface.co/"facebook/hf-seamless-m4t-medium") architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
vocab_size (`int`, *optional*, defaults to 256102):
|
| 37 |
+
Vocabulary size of the SeamlessM4T model. Defines the number of different tokens that can be represented by
|
| 38 |
+
the `inputs_ids` passed when calling [`~SeamlessM4TModel`], [`~SeamlessM4TForTextToSpeech`] or
|
| 39 |
+
[`~SeamlessM4TForTextToText`].
|
| 40 |
+
t2u_vocab_size (`int`, *optional*, defaults to 10082):
|
| 41 |
+
Unit vocabulary size of the SeamlessM4T model. Defines the number of different unit tokens that can be
|
| 42 |
+
represented by the `inputs_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4TModel`],
|
| 43 |
+
[`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`].
|
| 44 |
+
|
| 45 |
+
> Parameters shared across sub-models
|
| 46 |
+
|
| 47 |
+
hidden_size (`int`, *optional*, defaults to 1024):
|
| 48 |
+
Dimensionality of the "intermediate" layers in the architecture.
|
| 49 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 50 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 51 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 52 |
+
The epsilon used by the layer normalization layers.
|
| 53 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 54 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
| 55 |
+
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
| 56 |
+
The maximum sequence length that this model text encoder and decoder might ever be used with. Typically set
|
| 57 |
+
this to something large just in case (e.g., 512 or 1024 or 2048).
|
| 58 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
| 59 |
+
Whether the model is used as an encoder/decoder or not.
|
| 60 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.05):
|
| 61 |
+
The LayerDrop probability for the encoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 62 |
+
for more details.
|
| 63 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.05):
|
| 64 |
+
The LayerDrop probability for the decoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 65 |
+
for more details.
|
| 66 |
+
activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
|
| 67 |
+
The non-linear activation function (function or string) in the decoder and feed-forward layers. If string,
|
| 68 |
+
`"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
|
| 69 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
| 70 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler.
|
| 71 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
| 72 |
+
The dropout probability for all attention layers.
|
| 73 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 74 |
+
The dropout probability for all activation layers in the model.
|
| 75 |
+
scale_embedding (`bool`, *optional*, defaults to `True`):
|
| 76 |
+
Scale embeddings by diving by sqrt(d_model).
|
| 77 |
+
|
| 78 |
+
> Text encoder and text decoder specific parameters
|
| 79 |
+
|
| 80 |
+
encoder_layers (`int`, *optional*, defaults to 24):
|
| 81 |
+
Number of hidden layers in the Transformer text encoder.
|
| 82 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 83 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder.
|
| 84 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 85 |
+
Number of attention heads for each attention layer in the Transformer text encoder.
|
| 86 |
+
decoder_layers (`int`, *optional*, defaults to 24):
|
| 87 |
+
Number of hidden layers in the Transformer text decoder.
|
| 88 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 89 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder.
|
| 90 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 91 |
+
Number of attention heads for each attention layer in the Transformer text decoder.
|
| 92 |
+
decoder_start_token_id (`int`, *optional*, defaults to 3):
|
| 93 |
+
If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
|
| 94 |
+
applied in the text decoder.
|
| 95 |
+
max_new_tokens (`int`, *optional*, defaults to 256):
|
| 96 |
+
The maximum numbers of text tokens to generate, ignoring the number of tokens in the prompt.
|
| 97 |
+
pad_token_id (`int`, *optional*, defaults to 0):
|
| 98 |
+
The id of the _padding_ text token. Only applied to the text-decoder model.
|
| 99 |
+
bos_token_id (`int`, *optional*, defaults to 2):
|
| 100 |
+
The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model.
|
| 101 |
+
eos_token_id (`int`, *optional*, defaults to 3):
|
| 102 |
+
The id of the _end-of-stream_ text token. Only applied to the text-decoder model.
|
| 103 |
+
|
| 104 |
+
> Speech encoder specific parameters
|
| 105 |
+
|
| 106 |
+
speech_encoder_layers (`int`, *optional*, defaults to 24):
|
| 107 |
+
Number of hidden layers in the Transformer speech encoder.
|
| 108 |
+
speech_encoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 109 |
+
Number of attention heads for each attention layer in the Transformer speech encoder.
|
| 110 |
+
speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096):
|
| 111 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder.
|
| 112 |
+
speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
|
| 113 |
+
The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`,
|
| 114 |
+
`"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
|
| 115 |
+
speech_encoder_dropout (`float`, *optional*, defaults to 0.0):
|
| 116 |
+
The dropout probability for all layers in the speech encoder.
|
| 117 |
+
add_adapter (`bool`, *optional*, defaults to `True`):
|
| 118 |
+
Add an adapter layer on top of the speech encoder.
|
| 119 |
+
speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1):
|
| 120 |
+
The LayerDrop probability for the speech encoder. See the [LayerDrop paper](see
|
| 121 |
+
https://arxiv.org/abs/1909.11556) for more details.
|
| 122 |
+
feature_projection_input_dim (`int`, *optional*, defaults to 160):
|
| 123 |
+
Input dimension of the input feature projection of the speech encoder, i.e the dimension after processing
|
| 124 |
+
input audios with [`SeamlessM4TFeatureExtractor`].
|
| 125 |
+
num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
|
| 126 |
+
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
|
| 127 |
+
embeddings layer of the speech encoder.
|
| 128 |
+
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
|
| 129 |
+
Number of groups of 1D convolutional positional embeddings layer of the speech encoder.
|
| 130 |
+
adaptor_kernel_size (`int`, *optional*, defaults to 8):
|
| 131 |
+
Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 132 |
+
adaptor_stride (`int`, *optional*, defaults to 8):
|
| 133 |
+
Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 134 |
+
adaptor_dropout (`float`, *optional*, defaults to 0.1):
|
| 135 |
+
The dropout probability for all layers in the speech adapter.
|
| 136 |
+
num_adapter_layers (`int`, *optional*, defaults to 1):
|
| 137 |
+
Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
|
| 138 |
+
True`.
|
| 139 |
+
position_embeddings_type (`str`, *optional*, defaults to `"relative"`):
|
| 140 |
+
Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left
|
| 141 |
+
`None` no relative position embedding is applied. Only applied to the speech encoder.
|
| 142 |
+
rotary_embedding_base (`int`, *optional*, defaults to 10000):
|
| 143 |
+
If `"rotary"` position embeddings are used, defines the size of the embedding base. Only applied to the
|
| 144 |
+
speech encoder.
|
| 145 |
+
max_source_positions (`int`, *optional*, defaults to 4096):
|
| 146 |
+
if `"relative"` position embeddings are used, defines the maximum source input positions. Only applied to
|
| 147 |
+
the speech encoder.
|
| 148 |
+
conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
|
| 149 |
+
Kernel size of convolutional depthwise 1D layer in Conformer blocks. Only applied to the speech encoder.
|
| 150 |
+
|
| 151 |
+
> Text-To-Unit (t2u) model specific parameters
|
| 152 |
+
|
| 153 |
+
t2u_bos_token_id (`int`, *optional*, defaults to 0):
|
| 154 |
+
The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
|
| 155 |
+
t2u_pad_token_id (`int`, *optional*, defaults to 1):
|
| 156 |
+
The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model.
|
| 157 |
+
t2u_eos_token_id (`int`, *optional*, defaults to 2):
|
| 158 |
+
The id of the _end-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
|
| 159 |
+
t2u_decoder_start_token_id (`int`, *optional*, defaults to 2):
|
| 160 |
+
If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
|
| 161 |
+
applied to the text-to-unit seq2seq model.
|
| 162 |
+
t2u_max_new_tokens (`int`, *optional*, defaults to 1024):
|
| 163 |
+
The maximum numbers of unit tokens to generate, ignoring the number of tokens in the prompt. Only applied
|
| 164 |
+
to the text-to-unit seq2seq model.
|
| 165 |
+
t2u_encoder_layers (`int`, *optional*, defaults to 6):
|
| 166 |
+
Number of hidden layers in the Transformer text-to-unit encoder.
|
| 167 |
+
t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 168 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder.
|
| 169 |
+
t2u_encoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 170 |
+
Number of attention heads for each attention layer in the Transformer text-to-unit encoder.
|
| 171 |
+
t2u_decoder_layers (`int`, *optional*, defaults to 6):
|
| 172 |
+
Number of hidden layers in the Transformer text-to-unit decoder.
|
| 173 |
+
t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 174 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder.
|
| 175 |
+
t2u_decoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 176 |
+
Number of attention heads for each attention layer in the Transformer text-to-unit decoder.
|
| 177 |
+
t2u_max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 178 |
+
The maximum sequence length that this model text-to-unit component might ever be used with. Typically set
|
| 179 |
+
this to something large just in case (e.g., 512 or 1024 or 2048).
|
| 180 |
+
|
| 181 |
+
> Hifi-Gan Vocoder specific parameters
|
| 182 |
+
|
| 183 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 184 |
+
The sampling rate at which the output audio will be generated, expressed in hertz (Hz).
|
| 185 |
+
upsample_initial_channel (`int`, *optional*, defaults to 512):
|
| 186 |
+
The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only.
|
| 187 |
+
upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`):
|
| 188 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network.
|
| 189 |
+
The length of *upsample_rates* defines the number of convolutional layers and has to match the length of
|
| 190 |
+
*upsample_kernel_sizes*. Applies to the vocoder only.
|
| 191 |
+
upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`):
|
| 192 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling
|
| 193 |
+
network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match
|
| 194 |
+
the length of *upsample_rates*. Applies to the vocoder only.
|
| 195 |
+
resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
|
| 196 |
+
A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive
|
| 197 |
+
field fusion (MRF) module. Applies to the vocoder only.
|
| 198 |
+
resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
|
| 199 |
+
A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in
|
| 200 |
+
the multi-receptive field fusion (MRF) module. Applies to the vocoder only.
|
| 201 |
+
leaky_relu_slope (`float`, *optional*, defaults to 0.1):
|
| 202 |
+
The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder
|
| 203 |
+
only.
|
| 204 |
+
unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000):
|
| 205 |
+
Vocabulary size of the SeamlessM4T vocoder. Defines the number of different unit tokens that can be
|
| 206 |
+
represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4TModel`],
|
| 207 |
+
[`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`].
|
| 208 |
+
unit_embed_dim (`int`, *optional*, defaults to 1280):
|
| 209 |
+
The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only.
|
| 210 |
+
lang_embed_dim (`int`, *optional*, defaults to 256):
|
| 211 |
+
The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only.
|
| 212 |
+
spkr_embed_dim (`int`, *optional*, defaults to 256):
|
| 213 |
+
The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only.
|
| 214 |
+
vocoder_num_langs (`int`, *optional*, defaults to 36):
|
| 215 |
+
Number of langs supported by the vocoder. Might be different from `t2u_num_langs`.
|
| 216 |
+
vocoder_num_spkrs (`int`, *optional*, defaults to 200):
|
| 217 |
+
Number of speakers supported by the vocoder.
|
| 218 |
+
variance_predictor_kernel_size (`int`, *optional*, defaults to 3):
|
| 219 |
+
Kernel size of the duration predictor. Applies to the vocoder only.
|
| 220 |
+
var_pred_dropout (`float`, *optional*, defaults to 0.5):
|
| 221 |
+
The dropout probability of the duration predictor. Applies to the vocoder only.
|
| 222 |
+
vocoder_offset (`int`, *optional*, defaults to 4):
|
| 223 |
+
Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only.
|
| 224 |
+
|
| 225 |
+
```python
|
| 226 |
+
>>> from transformers import SeamlessM4TModel, SeamlessM4TConfig
|
| 227 |
+
|
| 228 |
+
>>> # Initializing a SeamlessM4T "facebook/hf-seamless-m4t-medium" style configuration
|
| 229 |
+
>>> configuration = SeamlessM4TConfig()
|
| 230 |
+
|
| 231 |
+
>>> # Initializing a model from the "facebook/hf-seamless-m4t-medium" style configuration
|
| 232 |
+
>>> model = SeamlessM4TModel(configuration)
|
| 233 |
+
|
| 234 |
+
>>> # Accessing the model configuration
|
| 235 |
+
>>> configuration = model.config
|
| 236 |
+
```"""
|
| 237 |
+
|
| 238 |
+
model_type = "seamless_m4t"
|
| 239 |
+
|
| 240 |
+
def __init__(
|
| 241 |
+
self,
|
| 242 |
+
vocab_size=256102,
|
| 243 |
+
t2u_vocab_size=10082,
|
| 244 |
+
# shared config
|
| 245 |
+
hidden_size=1024,
|
| 246 |
+
initializer_range=0.02,
|
| 247 |
+
layer_norm_eps=1e-5,
|
| 248 |
+
use_cache=True,
|
| 249 |
+
max_position_embeddings=1024,
|
| 250 |
+
is_encoder_decoder=True,
|
| 251 |
+
encoder_layerdrop=0.05,
|
| 252 |
+
decoder_layerdrop=0.05,
|
| 253 |
+
activation_function="relu",
|
| 254 |
+
dropout=0.1,
|
| 255 |
+
attention_dropout=0.1,
|
| 256 |
+
activation_dropout=0.0,
|
| 257 |
+
scale_embedding=True,
|
| 258 |
+
# text encoder|decoder
|
| 259 |
+
encoder_layers=24,
|
| 260 |
+
encoder_ffn_dim=8192,
|
| 261 |
+
encoder_attention_heads=16,
|
| 262 |
+
decoder_layers=24,
|
| 263 |
+
decoder_ffn_dim=8192,
|
| 264 |
+
decoder_attention_heads=16,
|
| 265 |
+
decoder_start_token_id=3,
|
| 266 |
+
max_new_tokens=256,
|
| 267 |
+
pad_token_id=0,
|
| 268 |
+
bos_token_id=2,
|
| 269 |
+
eos_token_id=3,
|
| 270 |
+
# speech_encoder
|
| 271 |
+
speech_encoder_layers=24,
|
| 272 |
+
speech_encoder_attention_heads=16,
|
| 273 |
+
speech_encoder_intermediate_size=4096,
|
| 274 |
+
speech_encoder_hidden_act="swish",
|
| 275 |
+
speech_encoder_dropout=0.0,
|
| 276 |
+
add_adapter=True,
|
| 277 |
+
speech_encoder_layerdrop=0.1,
|
| 278 |
+
feature_projection_input_dim=160,
|
| 279 |
+
num_conv_pos_embeddings=128,
|
| 280 |
+
num_conv_pos_embedding_groups=16,
|
| 281 |
+
adaptor_kernel_size=8,
|
| 282 |
+
adaptor_stride=8,
|
| 283 |
+
adaptor_dropout=0.1,
|
| 284 |
+
num_adapter_layers=1,
|
| 285 |
+
position_embeddings_type="relative",
|
| 286 |
+
rotary_embedding_base=10000,
|
| 287 |
+
max_source_positions=4096,
|
| 288 |
+
conv_depthwise_kernel_size=31,
|
| 289 |
+
# t2u config
|
| 290 |
+
t2u_bos_token_id=0,
|
| 291 |
+
t2u_pad_token_id=1,
|
| 292 |
+
t2u_eos_token_id=2,
|
| 293 |
+
t2u_decoder_start_token_id=2,
|
| 294 |
+
t2u_max_new_tokens=1024,
|
| 295 |
+
t2u_encoder_layers=6,
|
| 296 |
+
t2u_encoder_ffn_dim=8192,
|
| 297 |
+
t2u_encoder_attention_heads=16,
|
| 298 |
+
t2u_decoder_layers=6,
|
| 299 |
+
t2u_decoder_ffn_dim=8192,
|
| 300 |
+
t2u_decoder_attention_heads=16,
|
| 301 |
+
t2u_max_position_embeddings=2048,
|
| 302 |
+
# hifi-gan vocoder config
|
| 303 |
+
sampling_rate=16000,
|
| 304 |
+
upsample_initial_channel=512,
|
| 305 |
+
upsample_rates=[5, 4, 4, 2, 2],
|
| 306 |
+
upsample_kernel_sizes=[11, 8, 8, 4, 4],
|
| 307 |
+
resblock_kernel_sizes=[3, 7, 11],
|
| 308 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 309 |
+
leaky_relu_slope=0.1,
|
| 310 |
+
# specific to Code Hifi-Gan
|
| 311 |
+
unit_hifi_gan_vocab_size=10000,
|
| 312 |
+
unit_embed_dim=1280,
|
| 313 |
+
lang_embed_dim=256,
|
| 314 |
+
spkr_embed_dim=256,
|
| 315 |
+
vocoder_num_langs=36,
|
| 316 |
+
vocoder_num_spkrs=200,
|
| 317 |
+
variance_predictor_kernel_size=3,
|
| 318 |
+
var_pred_dropout=0.5,
|
| 319 |
+
vocoder_offset=4,
|
| 320 |
+
**kwargs,
|
| 321 |
+
):
|
| 322 |
+
# overall_config
|
| 323 |
+
self.vocab_size = vocab_size
|
| 324 |
+
self.t2u_vocab_size = t2u_vocab_size
|
| 325 |
+
self.hidden_size = hidden_size
|
| 326 |
+
self.initializer_range = initializer_range
|
| 327 |
+
self.layer_norm_eps = layer_norm_eps
|
| 328 |
+
self.max_position_embeddings = max_position_embeddings
|
| 329 |
+
self.use_cache = use_cache
|
| 330 |
+
self.max_new_tokens = max_new_tokens
|
| 331 |
+
self.encoder_layerdrop = encoder_layerdrop
|
| 332 |
+
self.decoder_layerdrop = decoder_layerdrop
|
| 333 |
+
self.activation_function = activation_function
|
| 334 |
+
self.dropout = dropout
|
| 335 |
+
self.attention_dropout = attention_dropout
|
| 336 |
+
self.activation_dropout = activation_dropout
|
| 337 |
+
self.scale_embedding = scale_embedding
|
| 338 |
+
# for proper config init
|
| 339 |
+
self.num_attention_heads = decoder_attention_heads
|
| 340 |
+
self.num_hidden_layers = decoder_layers
|
| 341 |
+
|
| 342 |
+
# text|unit encoder|decoder
|
| 343 |
+
self.encoder_layers = encoder_layers
|
| 344 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 345 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 346 |
+
self.decoder_layers = decoder_layers
|
| 347 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 348 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 349 |
+
|
| 350 |
+
# speech_encoder
|
| 351 |
+
self.speech_encoder_layers = speech_encoder_layers
|
| 352 |
+
self.speech_encoder_hidden_act = speech_encoder_hidden_act
|
| 353 |
+
self.speech_encoder_dropout = speech_encoder_dropout
|
| 354 |
+
self.speech_encoder_attention_heads = speech_encoder_attention_heads
|
| 355 |
+
self.speech_encoder_layerdrop = speech_encoder_layerdrop
|
| 356 |
+
self.speech_encoder_intermediate_size = speech_encoder_intermediate_size
|
| 357 |
+
self.feature_projection_input_dim = feature_projection_input_dim
|
| 358 |
+
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
| 359 |
+
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
| 360 |
+
self.adaptor_kernel_size = adaptor_kernel_size
|
| 361 |
+
self.adaptor_stride = adaptor_stride
|
| 362 |
+
self.adaptor_dropout = adaptor_dropout
|
| 363 |
+
self.num_adapter_layers = num_adapter_layers
|
| 364 |
+
self.position_embeddings_type = position_embeddings_type
|
| 365 |
+
self.rotary_embedding_base = rotary_embedding_base
|
| 366 |
+
self.max_source_positions = max_source_positions
|
| 367 |
+
self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
|
| 368 |
+
self.add_adapter = add_adapter
|
| 369 |
+
|
| 370 |
+
# t2u config
|
| 371 |
+
self.t2u_bos_token_id = t2u_bos_token_id
|
| 372 |
+
self.t2u_pad_token_id = t2u_pad_token_id
|
| 373 |
+
self.t2u_eos_token_id = t2u_eos_token_id
|
| 374 |
+
self.t2u_decoder_start_token_id = t2u_decoder_start_token_id
|
| 375 |
+
self.t2u_max_new_tokens = t2u_max_new_tokens
|
| 376 |
+
self.t2u_encoder_layers = t2u_encoder_layers
|
| 377 |
+
self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim
|
| 378 |
+
self.t2u_encoder_attention_heads = t2u_encoder_attention_heads
|
| 379 |
+
self.t2u_decoder_layers = t2u_decoder_layers
|
| 380 |
+
self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim
|
| 381 |
+
self.t2u_decoder_attention_heads = t2u_decoder_attention_heads
|
| 382 |
+
self.t2u_max_position_embeddings = t2u_max_position_embeddings
|
| 383 |
+
|
| 384 |
+
# hifi-gan vocoder config
|
| 385 |
+
# original parameters specific to Hifi-Gan
|
| 386 |
+
self.sampling_rate = sampling_rate
|
| 387 |
+
self.upsample_initial_channel = upsample_initial_channel
|
| 388 |
+
self.upsample_rates = upsample_rates
|
| 389 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
| 390 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
| 391 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
| 392 |
+
self.leaky_relu_slope = leaky_relu_slope
|
| 393 |
+
|
| 394 |
+
# specific to Code Hifi-Gan
|
| 395 |
+
self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size
|
| 396 |
+
self.unit_embed_dim = unit_embed_dim
|
| 397 |
+
self.lang_embed_dim = lang_embed_dim
|
| 398 |
+
self.spkr_embed_dim = spkr_embed_dim
|
| 399 |
+
self.vocoder_num_langs = vocoder_num_langs
|
| 400 |
+
self.vocoder_num_spkrs = vocoder_num_spkrs
|
| 401 |
+
self.variance_predictor_kernel_size = variance_predictor_kernel_size
|
| 402 |
+
self.var_pred_dropout = var_pred_dropout
|
| 403 |
+
self.vocoder_offset = vocoder_offset
|
| 404 |
+
|
| 405 |
+
super().__init__(
|
| 406 |
+
pad_token_id=pad_token_id,
|
| 407 |
+
bos_token_id=bos_token_id,
|
| 408 |
+
eos_token_id=eos_token_id,
|
| 409 |
+
decoder_start_token_id=decoder_start_token_id,
|
| 410 |
+
is_encoder_decoder=is_encoder_decoder,
|
| 411 |
+
max_position_embeddings=max_position_embeddings,
|
| 412 |
+
**kwargs,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
__all__ = ["SeamlessM4TConfig"]
|
docs/transformers/build/lib/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Feature extractor class for SeamlessM4T
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from typing import List, Optional, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from ...utils import is_torch_available
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_torch_available():
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
| 30 |
+
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
| 31 |
+
from ...feature_extraction_utils import BatchFeature
|
| 32 |
+
from ...utils import PaddingStrategy, TensorType, logging
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
|
| 39 |
+
r"""
|
| 40 |
+
Constructs a SeamlessM4T feature extractor.
|
| 41 |
+
|
| 42 |
+
This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users
|
| 43 |
+
should refer to this superclass for more information regarding those methods.
|
| 44 |
+
|
| 45 |
+
This class extracts mel-filter bank features from raw speech.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
feature_size (`int`, *optional*, defaults to 80):
|
| 49 |
+
The feature dimension of the extracted features.
|
| 50 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 51 |
+
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
|
| 52 |
+
num_mel_bins (`int`, *optional*, defaults to 80):
|
| 53 |
+
Number of Mel-frequency bins.
|
| 54 |
+
padding_value (`float`, *optional*, defaults to 0.0):
|
| 55 |
+
The value that is used to fill the padding vectors.
|
| 56 |
+
stride (`int`, *optional*, defaults to 2):
|
| 57 |
+
Stride used to reshape audios from shape (batch_size,num_frames,num_mel_bins) to
|
| 58 |
+
(batch_size,num_frames//stride,num_mel_bins*stride).
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
model_input_names = ["input_features", "attention_mask"]
|
| 62 |
+
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
feature_size=80,
|
| 66 |
+
sampling_rate=16000,
|
| 67 |
+
num_mel_bins=80,
|
| 68 |
+
padding_value=0.0,
|
| 69 |
+
stride=2,
|
| 70 |
+
**kwargs,
|
| 71 |
+
):
|
| 72 |
+
self.num_mel_bins = num_mel_bins
|
| 73 |
+
self.return_attention_mask = True
|
| 74 |
+
self.stride = stride
|
| 75 |
+
|
| 76 |
+
mel_filters = mel_filter_bank(
|
| 77 |
+
num_frequency_bins=257,
|
| 78 |
+
num_mel_filters=self.num_mel_bins,
|
| 79 |
+
min_frequency=20,
|
| 80 |
+
max_frequency=sampling_rate // 2,
|
| 81 |
+
sampling_rate=sampling_rate,
|
| 82 |
+
norm=None,
|
| 83 |
+
mel_scale="kaldi",
|
| 84 |
+
triangularize_in_mel_space=True,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
self.mel_filters = mel_filters
|
| 88 |
+
self.window = window_function(400, "povey", periodic=False)
|
| 89 |
+
|
| 90 |
+
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
|
| 94 |
+
def zero_mean_unit_var_norm(
|
| 95 |
+
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
|
| 96 |
+
) -> List[np.ndarray]:
|
| 97 |
+
"""
|
| 98 |
+
Every array in the list is normalized to have zero mean and unit variance
|
| 99 |
+
"""
|
| 100 |
+
if attention_mask is not None:
|
| 101 |
+
attention_mask = np.array(attention_mask, np.int32)
|
| 102 |
+
normed_input_values = []
|
| 103 |
+
|
| 104 |
+
for vector, length in zip(input_values, attention_mask.sum(-1)):
|
| 105 |
+
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
|
| 106 |
+
if length < normed_slice.shape[0]:
|
| 107 |
+
normed_slice[length:] = padding_value
|
| 108 |
+
|
| 109 |
+
normed_input_values.append(normed_slice)
|
| 110 |
+
else:
|
| 111 |
+
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
|
| 112 |
+
|
| 113 |
+
return normed_input_values
|
| 114 |
+
|
| 115 |
+
def _extract_fbank_features(
|
| 116 |
+
self,
|
| 117 |
+
waveform: np.ndarray,
|
| 118 |
+
) -> np.ndarray:
|
| 119 |
+
"""
|
| 120 |
+
Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
|
| 121 |
+
and hence the waveform should not be normalized before feature extraction.
|
| 122 |
+
"""
|
| 123 |
+
# by default, it extracts the left channel if stereo
|
| 124 |
+
if len(waveform.shape) == 2:
|
| 125 |
+
waveform = waveform[0]
|
| 126 |
+
|
| 127 |
+
waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers
|
| 128 |
+
features = spectrogram(
|
| 129 |
+
waveform,
|
| 130 |
+
self.window,
|
| 131 |
+
frame_length=400,
|
| 132 |
+
hop_length=160,
|
| 133 |
+
fft_length=512,
|
| 134 |
+
power=2.0,
|
| 135 |
+
center=False,
|
| 136 |
+
preemphasis=0.97,
|
| 137 |
+
mel_filters=self.mel_filters,
|
| 138 |
+
log_mel="log",
|
| 139 |
+
mel_floor=1.192092955078125e-07,
|
| 140 |
+
remove_dc_offset=True,
|
| 141 |
+
).T
|
| 142 |
+
return features
|
| 143 |
+
|
| 144 |
+
def __call__(
|
| 145 |
+
self,
|
| 146 |
+
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
| 147 |
+
padding: Union[bool, str, PaddingStrategy] = True,
|
| 148 |
+
pad_to_multiple_of: Optional[int] = 2,
|
| 149 |
+
max_length: Optional[int] = None,
|
| 150 |
+
truncation: bool = False,
|
| 151 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 152 |
+
sampling_rate: Optional[int] = None,
|
| 153 |
+
return_attention_mask: Optional[bool] = None,
|
| 154 |
+
do_normalize_per_mel_bins: Optional[bool] = True,
|
| 155 |
+
**kwargs,
|
| 156 |
+
) -> BatchFeature:
|
| 157 |
+
"""
|
| 158 |
+
Main method to featurize and prepare for the model one or several sequence(s).
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
raw_speech (`np.ndarray`, `torch.Tensor`, `List[float]`, `List[np.ndarray]`, `List[torch.Tensor]`,
|
| 162 |
+
`List[List[float]]`, `List[List[List[float]]]`):
|
| 163 |
+
The sequence or batch of sequences to be padded. Each sequence can be a numpy array,
|
| 164 |
+
a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors,
|
| 165 |
+
a list of list of float values or a list of a list of list of float values.
|
| 166 |
+
If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `List[float]`, `raw_speech` is
|
| 167 |
+
considered a single-channel, single-sample sound. In all other cases, the first dimension of
|
| 168 |
+
`raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `List[...]`,
|
| 169 |
+
corresponds to the number of samples in the batch, and the number of channels
|
| 170 |
+
(i.e. mono or stereo character) is derived from the other dimensions
|
| 171 |
+
(1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
|
| 172 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 173 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
| 174 |
+
index) among:
|
| 175 |
+
|
| 176 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
| 177 |
+
sequence if provided).
|
| 178 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 179 |
+
acceptable input length for the model if that argument is not provided.
|
| 180 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
| 181 |
+
lengths).
|
| 182 |
+
pad_to_multiple_of (`int`, *optional*, defaults to 2):
|
| 183 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 184 |
+
|
| 185 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
| 186 |
+
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
|
| 187 |
+
max_length (`int`, *optional*):
|
| 188 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 189 |
+
truncation (`bool`):
|
| 190 |
+
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
|
| 191 |
+
return_attention_mask (`bool`, *optional*):
|
| 192 |
+
Whether to return the attention mask. If left to the default, will return the attention mask according
|
| 193 |
+
to the specific feature_extractor's default.
|
| 194 |
+
|
| 195 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 196 |
+
|
| 197 |
+
<Tip>
|
| 198 |
+
|
| 199 |
+
For SeamlessM4T models, `attention_mask` should always be passed for batched inference, to avoid subtle
|
| 200 |
+
bugs.
|
| 201 |
+
|
| 202 |
+
</Tip>
|
| 203 |
+
|
| 204 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 205 |
+
If set, will return tensors instead of list of python integers. Acceptable values are:
|
| 206 |
+
|
| 207 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 208 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 209 |
+
- `'np'`: Return Numpy `np.ndarray` objects.
|
| 210 |
+
sampling_rate (`int`, *optional*):
|
| 211 |
+
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
|
| 212 |
+
`sampling_rate` at the forward call to prevent silent errors.
|
| 213 |
+
do_normalize_per_mel_bins (`bool`, *optional*, defaults to `True`):
|
| 214 |
+
Whether or not to zero-mean unit-variance normalize the input per mel-channel.
|
| 215 |
+
kwargs (*optional*):
|
| 216 |
+
Remaining dictionary of keyword arguments that will be passed to the tokenizer or the feature
|
| 217 |
+
extractor.
|
| 218 |
+
"""
|
| 219 |
+
if sampling_rate is not None:
|
| 220 |
+
if sampling_rate != self.sampling_rate:
|
| 221 |
+
raise ValueError(
|
| 222 |
+
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
| 223 |
+
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
|
| 224 |
+
f" {self.sampling_rate} and not {sampling_rate}."
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
logger.warning(
|
| 228 |
+
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
|
| 229 |
+
"Failing to do so can result in silent errors that might be hard to debug."
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
return_attention_mask = (
|
| 233 |
+
return_attention_mask if return_attention_mask is not None else self.return_attention_mask
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
| 237 |
+
if is_batched_numpy and len(raw_speech.shape) > 3:
|
| 238 |
+
raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")
|
| 239 |
+
|
| 240 |
+
acceptable_types = (
|
| 241 |
+
(torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list)
|
| 242 |
+
)
|
| 243 |
+
is_batched = is_batched_numpy or (
|
| 244 |
+
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types))
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
if is_batched:
|
| 248 |
+
raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
|
| 249 |
+
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
| 250 |
+
raw_speech = np.asarray(raw_speech, dtype=np.float32)
|
| 251 |
+
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
|
| 252 |
+
raw_speech = raw_speech.astype(np.float32)
|
| 253 |
+
|
| 254 |
+
# always return batch
|
| 255 |
+
if not is_batched:
|
| 256 |
+
raw_speech = [raw_speech]
|
| 257 |
+
|
| 258 |
+
# extract fbank features
|
| 259 |
+
features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
|
| 260 |
+
|
| 261 |
+
if do_normalize_per_mel_bins:
|
| 262 |
+
# torch defaults to ddof=1, and numpy defaults to ddof=0
|
| 263 |
+
features = [
|
| 264 |
+
(x - np.expand_dims(x.mean(0), 0)) / np.sqrt(np.expand_dims(x.var(0, ddof=1), 0) + 1e-7)
|
| 265 |
+
for x in features
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
# convert into correct format for padding
|
| 269 |
+
encoded_inputs = BatchFeature({"input_features": features})
|
| 270 |
+
|
| 271 |
+
padded_inputs = self.pad(
|
| 272 |
+
encoded_inputs,
|
| 273 |
+
padding=padding,
|
| 274 |
+
max_length=max_length,
|
| 275 |
+
truncation=truncation,
|
| 276 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 277 |
+
return_attention_mask=True,
|
| 278 |
+
return_tensors="np",
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# SeamlessM4T needs to process extracted features
|
| 282 |
+
input_features = padded_inputs.get("input_features")
|
| 283 |
+
attention_mask = padded_inputs.pop("attention_mask")
|
| 284 |
+
|
| 285 |
+
batch_size, num_frames, num_channels = input_features.shape
|
| 286 |
+
|
| 287 |
+
remainder = num_frames % self.stride
|
| 288 |
+
if remainder != 0:
|
| 289 |
+
input_features = input_features[:, : num_frames - remainder, :]
|
| 290 |
+
attention_mask = attention_mask[:, : num_frames - remainder]
|
| 291 |
+
|
| 292 |
+
input_features = np.reshape(
|
| 293 |
+
input_features, (batch_size, num_frames // self.stride, num_channels * self.stride)
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
indices = np.arange(0, num_frames - remainder)
|
| 297 |
+
attention_mask = attention_mask[:, indices % self.stride == 1]
|
| 298 |
+
|
| 299 |
+
padded_inputs["input_features"] = input_features
|
| 300 |
+
if return_attention_mask:
|
| 301 |
+
padded_inputs["attention_mask"] = attention_mask
|
| 302 |
+
|
| 303 |
+
if return_tensors is not None:
|
| 304 |
+
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
| 305 |
+
|
| 306 |
+
return padded_inputs
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
__all__ = ["SeamlessM4TFeatureExtractor"]
|
docs/transformers/build/lib/transformers/models/seamless_m4t/tokenization_seamless_m4t.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for SeamlessM4T."""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from shutil import copyfile
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import sentencepiece as spm
|
| 22 |
+
|
| 23 |
+
from ...convert_slow_tokenizer import import_protobuf
|
| 24 |
+
from ...tokenization_utils import (
|
| 25 |
+
BatchEncoding,
|
| 26 |
+
PreTokenizedInput,
|
| 27 |
+
PreTrainedTokenizer,
|
| 28 |
+
TextInput,
|
| 29 |
+
)
|
| 30 |
+
from ...tokenization_utils_base import AddedToken
|
| 31 |
+
from ...utils import PaddingStrategy, logging
|
| 32 |
+
from ...utils.import_utils import requires
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
SPIECE_UNDERLINE = "▁"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@requires(backends=("sentencepiece",))
|
| 45 |
+
class SeamlessM4TTokenizer(PreTrainedTokenizer):
|
| 46 |
+
"""
|
| 47 |
+
Construct a SeamlessM4T tokenizer.
|
| 48 |
+
|
| 49 |
+
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
|
| 50 |
+
[SentencePiece](https://github.com/google/sentencepiece).
|
| 51 |
+
|
| 52 |
+
The tokenization method is `<language code> <tokens> <eos>` for source language documents, and `<eos> <language
|
| 53 |
+
code> <tokens> <eos>` for target language documents.
|
| 54 |
+
|
| 55 |
+
Examples:
|
| 56 |
+
|
| 57 |
+
```python
|
| 58 |
+
>>> from transformers import SeamlessM4TTokenizer
|
| 59 |
+
|
| 60 |
+
>>> tokenizer = SeamlessM4TTokenizer.from_pretrained(
|
| 61 |
+
... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra"
|
| 62 |
+
... )
|
| 63 |
+
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
| 64 |
+
>>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
|
| 65 |
+
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
vocab_file (`str`):
|
| 70 |
+
Path to the vocabulary file.
|
| 71 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 72 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
| 73 |
+
|
| 74 |
+
<Tip>
|
| 75 |
+
|
| 76 |
+
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
| 77 |
+
sequence. The token used is the `cls_token`.
|
| 78 |
+
|
| 79 |
+
</Tip>
|
| 80 |
+
|
| 81 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 82 |
+
The end of sequence token.
|
| 83 |
+
|
| 84 |
+
<Tip>
|
| 85 |
+
|
| 86 |
+
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
|
| 87 |
+
The token used is the `sep_token`.
|
| 88 |
+
|
| 89 |
+
</Tip>
|
| 90 |
+
|
| 91 |
+
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
| 92 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 93 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 94 |
+
token of a sequence built with special tokens.
|
| 95 |
+
cls_token (`str`, *optional*, defaults to `"<s>"`):
|
| 96 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 97 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 98 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 99 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 100 |
+
token instead.
|
| 101 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 102 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 103 |
+
tokenizer_file (`str`, *optional*):
|
| 104 |
+
The path to a tokenizer file to use instead of the vocab file.
|
| 105 |
+
src_lang (`str`, *optional*, defaults to `"eng"`):
|
| 106 |
+
The language to use as source language for translation.
|
| 107 |
+
tgt_lang (`str`, *optional*, defaults to `"fra"`):
|
| 108 |
+
The language to use as target language for translation.
|
| 109 |
+
sp_model_kwargs (`Dict[str, Any]`, *optional*):
|
| 110 |
+
Additional keyword arguments to pass to the model initialization.
|
| 111 |
+
additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
|
| 112 |
+
A tuple or a list of additional special tokens. Can be used to specify the list of languages that will be
|
| 113 |
+
supported by the tokenizer.
|
| 114 |
+
add_prefix_space (`bool`, *optional*, defaults to `True`):
|
| 115 |
+
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
| 116 |
+
other word.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 120 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 121 |
+
|
| 122 |
+
prefix_tokens: List[int] = []
|
| 123 |
+
suffix_tokens: List[int] = []
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
vocab_file,
|
| 128 |
+
bos_token="<s>",
|
| 129 |
+
eos_token="</s>",
|
| 130 |
+
sep_token="</s>",
|
| 131 |
+
cls_token="<s>",
|
| 132 |
+
unk_token="<unk>",
|
| 133 |
+
pad_token="<pad>",
|
| 134 |
+
tokenizer_file=None,
|
| 135 |
+
src_lang="eng",
|
| 136 |
+
tgt_lang="fra",
|
| 137 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 138 |
+
additional_special_tokens=None,
|
| 139 |
+
add_prefix_space=True,
|
| 140 |
+
**kwargs,
|
| 141 |
+
):
|
| 142 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 143 |
+
# Add this unused argument to keep some important Copied from statements
|
| 144 |
+
self.legacy = False
|
| 145 |
+
self.vocab_file = vocab_file
|
| 146 |
+
|
| 147 |
+
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
|
| 148 |
+
|
| 149 |
+
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
|
| 150 |
+
# -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
|
| 151 |
+
# spm | '<unk>' | '<s>' | '</s>' | 'an' | 'en' | '_d' | 'er' | 'in' | '_s' | '_a'
|
| 152 |
+
# fairseq | '<pad>' | '<unk>' | '<s>' | '</s>' | 'an' | 'en' | '▁d' | 'er' | 'in' | '▁s'
|
| 153 |
+
|
| 154 |
+
# Mimic fairseq token-to-id alignment for the first 4 token
|
| 155 |
+
self._added_tokens_decoder = {
|
| 156 |
+
0: AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token,
|
| 157 |
+
1: AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token,
|
| 158 |
+
2: AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token,
|
| 159 |
+
3: AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token,
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
# The first "real" token "an" has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
| 163 |
+
self.fairseq_offset = 1
|
| 164 |
+
|
| 165 |
+
self.sp_model_size = len(self.sp_model)
|
| 166 |
+
|
| 167 |
+
self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang
|
| 168 |
+
self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang
|
| 169 |
+
self.add_prefix_space = add_prefix_space
|
| 170 |
+
|
| 171 |
+
super().__init__(
|
| 172 |
+
bos_token=bos_token,
|
| 173 |
+
eos_token=eos_token,
|
| 174 |
+
unk_token=unk_token,
|
| 175 |
+
sep_token=sep_token,
|
| 176 |
+
cls_token=cls_token,
|
| 177 |
+
pad_token=pad_token,
|
| 178 |
+
tokenizer_file=tokenizer_file,
|
| 179 |
+
src_lang=src_lang,
|
| 180 |
+
tgt_lang=tgt_lang,
|
| 181 |
+
additional_special_tokens=additional_special_tokens,
|
| 182 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
| 183 |
+
add_prefix_space=add_prefix_space,
|
| 184 |
+
**kwargs,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
self.set_src_lang_special_tokens(self._src_lang)
|
| 188 |
+
self.set_tgt_lang_special_tokens(self._tgt_lang)
|
| 189 |
+
|
| 190 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__getstate__
|
| 191 |
+
def __getstate__(self):
|
| 192 |
+
state = self.__dict__.copy()
|
| 193 |
+
state["sp_model"] = None
|
| 194 |
+
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
| 195 |
+
return state
|
| 196 |
+
|
| 197 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__setstate__
|
| 198 |
+
def __setstate__(self, d):
|
| 199 |
+
self.__dict__ = d
|
| 200 |
+
|
| 201 |
+
# for backward compatibility
|
| 202 |
+
if not hasattr(self, "sp_model_kwargs"):
|
| 203 |
+
self.sp_model_kwargs = {}
|
| 204 |
+
|
| 205 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 206 |
+
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def vocab_size(self):
|
| 210 |
+
return len(self.sp_model)
|
| 211 |
+
|
| 212 |
+
def __call__(
|
| 213 |
+
self,
|
| 214 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 215 |
+
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
| 216 |
+
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 217 |
+
text_pair_target: Optional[
|
| 218 |
+
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
|
| 219 |
+
] = None,
|
| 220 |
+
padding: Union[bool, str, PaddingStrategy] = True,
|
| 221 |
+
pad_to_multiple_of: Optional[int] = 2,
|
| 222 |
+
src_lang: Optional[str] = None,
|
| 223 |
+
tgt_lang: Optional[str] = None,
|
| 224 |
+
**kwargs,
|
| 225 |
+
):
|
| 226 |
+
"""
|
| 227 |
+
Args:
|
| 228 |
+
text (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 229 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 230 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 231 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 232 |
+
text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 233 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 234 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 235 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 236 |
+
text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 237 |
+
The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
|
| 238 |
+
list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
|
| 239 |
+
you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 240 |
+
text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
| 241 |
+
The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
|
| 242 |
+
list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
|
| 243 |
+
you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 244 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 245 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
| 246 |
+
index) among:
|
| 247 |
+
|
| 248 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
| 249 |
+
sequence if provided).
|
| 250 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 251 |
+
acceptable input length for the model if that argument is not provided.
|
| 252 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
| 253 |
+
lengths).
|
| 254 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 255 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 256 |
+
|
| 257 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
| 258 |
+
`>= 7.5` (Volta).
|
| 259 |
+
src_lang (`str`, *optional*):
|
| 260 |
+
A string representing the source language. If not specified, the last `src_lang` specified (either
|
| 261 |
+
during initialization or when calling this tokenizer) will be used.
|
| 262 |
+
tgt_lang (`str`, *optional*):
|
| 263 |
+
A string representing the target language. If not specified, the last `tgt_lang` specified (either
|
| 264 |
+
during initialization or when calling this tokenizer) will be used.
|
| 265 |
+
kwargs (*optional*):
|
| 266 |
+
Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizer.__call__`].
|
| 267 |
+
"""
|
| 268 |
+
if src_lang is not None:
|
| 269 |
+
self.src_lang = src_lang
|
| 270 |
+
if tgt_lang is not None:
|
| 271 |
+
self.tgt_lang = tgt_lang
|
| 272 |
+
|
| 273 |
+
output = super().__call__(
|
| 274 |
+
text=text,
|
| 275 |
+
text_pair=text_pair,
|
| 276 |
+
text_target=text_target,
|
| 277 |
+
text_pair_target=text_pair_target,
|
| 278 |
+
padding=padding,
|
| 279 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 280 |
+
**kwargs,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
return BatchEncoding(output, tensor_type=kwargs.get("return_tensors"))
|
| 284 |
+
|
| 285 |
+
@property
|
| 286 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang
|
| 287 |
+
def src_lang(self) -> str:
|
| 288 |
+
return self._src_lang
|
| 289 |
+
|
| 290 |
+
@src_lang.setter
|
| 291 |
+
def src_lang(self, new_src_lang: str) -> None:
|
| 292 |
+
if "__" not in new_src_lang:
|
| 293 |
+
self._src_lang = f"__{new_src_lang}__"
|
| 294 |
+
else:
|
| 295 |
+
self._src_lang = new_src_lang
|
| 296 |
+
self.set_src_lang_special_tokens(self._src_lang)
|
| 297 |
+
|
| 298 |
+
@property
|
| 299 |
+
def tgt_lang(self) -> str:
|
| 300 |
+
return self._tgt_lang
|
| 301 |
+
|
| 302 |
+
@tgt_lang.setter
|
| 303 |
+
def tgt_lang(self, new_tgt_lang: str) -> None:
|
| 304 |
+
if "__" not in new_tgt_lang:
|
| 305 |
+
self._tgt_lang = f"__{new_tgt_lang}__"
|
| 306 |
+
else:
|
| 307 |
+
self._tgt_lang = new_tgt_lang
|
| 308 |
+
self.set_tgt_lang_special_tokens(self._tgt_lang)
|
| 309 |
+
|
| 310 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.get_special_tokens_mask
|
| 311 |
+
def get_special_tokens_mask(
|
| 312 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 313 |
+
) -> List[int]:
|
| 314 |
+
"""
|
| 315 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 316 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
token_ids_0 (`List[int]`):
|
| 320 |
+
List of IDs.
|
| 321 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 322 |
+
Optional second list of IDs for sequence pairs.
|
| 323 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 324 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
if already_has_special_tokens:
|
| 331 |
+
return super().get_special_tokens_mask(
|
| 332 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
prefix_ones = [1] * len(self.prefix_tokens)
|
| 336 |
+
suffix_ones = [1] * len(self.suffix_tokens)
|
| 337 |
+
if token_ids_1 is None:
|
| 338 |
+
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
| 339 |
+
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
| 340 |
+
|
| 341 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.build_inputs_with_special_tokens
|
| 342 |
+
def build_inputs_with_special_tokens(
|
| 343 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 344 |
+
) -> List[int]:
|
| 345 |
+
"""
|
| 346 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 347 |
+
adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence:
|
| 348 |
+
|
| 349 |
+
- `input_ids` (for encoder) `X [eos, src_lang_code]`
|
| 350 |
+
- `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
|
| 351 |
+
|
| 352 |
+
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
| 353 |
+
separator.
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
token_ids_0 (`List[int]`):
|
| 357 |
+
List of IDs to which the special tokens will be added.
|
| 358 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 359 |
+
Optional second list of IDs for sequence pairs.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 363 |
+
"""
|
| 364 |
+
if token_ids_1 is None:
|
| 365 |
+
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
| 366 |
+
# We don't expect to process pairs, but leave the pair logic for API consistency
|
| 367 |
+
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
| 368 |
+
|
| 369 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.create_token_type_ids_from_sequences
|
| 370 |
+
def create_token_type_ids_from_sequences(
|
| 371 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 372 |
+
) -> List[int]:
|
| 373 |
+
"""
|
| 374 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not
|
| 375 |
+
make use of token type ids, therefore a list of zeros is returned.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
token_ids_0 (`List[int]`):
|
| 379 |
+
List of IDs.
|
| 380 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 381 |
+
Optional second list of IDs for sequence pairs.
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
`List[int]`: List of zeros.
|
| 385 |
+
|
| 386 |
+
"""
|
| 387 |
+
|
| 388 |
+
sep = [self.sep_token_id]
|
| 389 |
+
cls = [self.cls_token_id]
|
| 390 |
+
|
| 391 |
+
if token_ids_1 is None:
|
| 392 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 393 |
+
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
| 394 |
+
|
| 395 |
+
def _build_translation_inputs(
|
| 396 |
+
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
| 397 |
+
):
|
| 398 |
+
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
| 399 |
+
if src_lang is None or tgt_lang is None:
|
| 400 |
+
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model.")
|
| 401 |
+
self.src_lang = src_lang
|
| 402 |
+
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
| 403 |
+
if "__" not in tgt_lang:
|
| 404 |
+
tgt_lang = f"__{tgt_lang}__"
|
| 405 |
+
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
| 406 |
+
inputs["forced_bos_token_id"] = tgt_lang_id
|
| 407 |
+
return inputs
|
| 408 |
+
|
| 409 |
+
def get_vocab(self):
|
| 410 |
+
vocab = {
|
| 411 |
+
self.convert_ids_to_tokens(i): i for i in range(self.fairseq_offset, self.vocab_size + self.fairseq_offset)
|
| 412 |
+
}
|
| 413 |
+
vocab.update(self.added_tokens_encoder)
|
| 414 |
+
return vocab
|
| 415 |
+
|
| 416 |
+
@property
|
| 417 |
+
def unk_token_length(self):
|
| 418 |
+
return len(self.sp_model.encode(str(self.unk_token)))
|
| 419 |
+
|
| 420 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
|
| 421 |
+
def get_spm_processor(self, from_slow=False):
|
| 422 |
+
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 423 |
+
if self.legacy or from_slow: # no dependency on protobuf
|
| 424 |
+
tokenizer.Load(self.vocab_file)
|
| 425 |
+
return tokenizer
|
| 426 |
+
|
| 427 |
+
with open(self.vocab_file, "rb") as f:
|
| 428 |
+
sp_model = f.read()
|
| 429 |
+
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
|
| 430 |
+
model = model_pb2.ModelProto.FromString(sp_model)
|
| 431 |
+
normalizer_spec = model_pb2.NormalizerSpec()
|
| 432 |
+
normalizer_spec.add_dummy_prefix = False
|
| 433 |
+
model.normalizer_spec.MergeFrom(normalizer_spec)
|
| 434 |
+
sp_model = model.SerializeToString()
|
| 435 |
+
tokenizer.LoadFromSerializedProto(sp_model)
|
| 436 |
+
return tokenizer
|
| 437 |
+
|
| 438 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
|
| 439 |
+
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
|
| 440 |
+
"""
|
| 441 |
+
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
|
| 442 |
+
first token is special.
|
| 443 |
+
"""
|
| 444 |
+
if self.legacy or len(text) == 0:
|
| 445 |
+
return super().tokenize(text, **kwargs)
|
| 446 |
+
|
| 447 |
+
text = text.replace(SPIECE_UNDERLINE, " ")
|
| 448 |
+
if self.add_prefix_space:
|
| 449 |
+
text = SPIECE_UNDERLINE + text
|
| 450 |
+
|
| 451 |
+
tokens = super().tokenize(text, **kwargs)
|
| 452 |
+
|
| 453 |
+
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
|
| 454 |
+
tokens = tokens[1:]
|
| 455 |
+
return tokens
|
| 456 |
+
|
| 457 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
|
| 458 |
+
def _tokenize(self, text, **kwargs):
|
| 459 |
+
"""
|
| 460 |
+
Returns a tokenized string.
|
| 461 |
+
|
| 462 |
+
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
|
| 463 |
+
SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
|
| 464 |
+
`['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
|
| 465 |
+
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
| 466 |
+
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
| 467 |
+
"""
|
| 468 |
+
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
|
| 469 |
+
return self.sp_model.encode(text, out_type=str)
|
| 470 |
+
|
| 471 |
+
# 1. Encode string + prefix ex: "<unk> Hey"
|
| 472 |
+
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
| 473 |
+
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
| 474 |
+
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
| 475 |
+
|
| 476 |
+
def _convert_token_to_id(self, token):
|
| 477 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 478 |
+
spm_id = self.sp_model.PieceToId(token)
|
| 479 |
+
|
| 480 |
+
# Need to return unknown token if the SP model returned 0
|
| 481 |
+
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
|
| 482 |
+
|
| 483 |
+
def _convert_id_to_token(self, index):
|
| 484 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 485 |
+
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
| 486 |
+
|
| 487 |
+
def convert_tokens_to_string(self, tokens):
|
| 488 |
+
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
| 489 |
+
# since we manually add the prefix space, we have to remove it when decoding
|
| 490 |
+
if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
|
| 491 |
+
tokens[0] = tokens[0][1:]
|
| 492 |
+
|
| 493 |
+
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
|
| 494 |
+
return out_string
|
| 495 |
+
|
| 496 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.save_vocabulary
|
| 497 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 498 |
+
if not os.path.isdir(save_directory):
|
| 499 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 500 |
+
return
|
| 501 |
+
out_vocab_file = os.path.join(
|
| 502 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 506 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 507 |
+
elif not os.path.isfile(self.vocab_file):
|
| 508 |
+
with open(out_vocab_file, "wb") as fi:
|
| 509 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 510 |
+
fi.write(content_spiece_model)
|
| 511 |
+
|
| 512 |
+
return (out_vocab_file,)
|
| 513 |
+
|
| 514 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.prepare_seq2seq_batch with eng_Latn->eng, fra_Latn->fra
|
| 515 |
+
def prepare_seq2seq_batch(
|
| 516 |
+
self,
|
| 517 |
+
src_texts: List[str],
|
| 518 |
+
src_lang: str = "eng",
|
| 519 |
+
tgt_texts: Optional[List[str]] = None,
|
| 520 |
+
tgt_lang: str = "fra",
|
| 521 |
+
**kwargs,
|
| 522 |
+
) -> BatchEncoding:
|
| 523 |
+
self.src_lang = src_lang
|
| 524 |
+
self.tgt_lang = tgt_lang
|
| 525 |
+
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
| 526 |
+
|
| 527 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_input_mode
|
| 528 |
+
def _switch_to_input_mode(self):
|
| 529 |
+
return self.set_src_lang_special_tokens(self.src_lang)
|
| 530 |
+
|
| 531 |
+
# Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_target_mode
|
| 532 |
+
def _switch_to_target_mode(self):
|
| 533 |
+
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
| 534 |
+
|
| 535 |
+
def set_src_lang_special_tokens(self, src_lang) -> None:
|
| 536 |
+
"""Reset the special tokens to the source lang setting.
|
| 537 |
+
Prefix=[src_lang_code], suffix = [eos]
|
| 538 |
+
"""
|
| 539 |
+
self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
|
| 540 |
+
self.init_kwargs["src_lang"] = src_lang
|
| 541 |
+
|
| 542 |
+
if self.cur_lang_code == self.unk_token_id:
|
| 543 |
+
logger.warning_once(
|
| 544 |
+
f"`src_lang={src_lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
self.prefix_tokens = [self.cur_lang_code]
|
| 548 |
+
self.suffix_tokens = [self.eos_token_id]
|
| 549 |
+
|
| 550 |
+
# https://github.com/facebookresearch/fairseq2/blob/c53f18e6be6b8b46b722f2249b8397b7eccd7ad3/src/fairseq2/models/nllb/tokenizer.py#L112-L116
|
| 551 |
+
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
| 552 |
+
"""Reset the special tokens to the target lang setting.
|
| 553 |
+
Prefix=[eos, tgt_lang_code] and suffix=[eos].
|
| 554 |
+
"""
|
| 555 |
+
self.cur_lang_code = self.convert_tokens_to_ids(lang)
|
| 556 |
+
self.init_kwargs["tgt_lang"] = lang
|
| 557 |
+
|
| 558 |
+
if self.cur_lang_code == self.unk_token_id:
|
| 559 |
+
logger.warning_once(
|
| 560 |
+
f"`tgt_lang={lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
self.prefix_tokens = [self.eos_token_id, self.cur_lang_code]
|
| 564 |
+
self.suffix_tokens = [self.eos_token_id]
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
__all__ = ["SeamlessM4TTokenizer"]
|
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_seamless_m4t_v2 import *
|
| 22 |
+
from .modeling_seamless_m4t_v2 import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""SeamlessM4Tv2 model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SeamlessM4Tv2Config(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`~SeamlessM4Tv2Model`]. It is used to instantiate
|
| 27 |
+
an SeamlessM4Tv2 model according to the specified arguments, defining the model architecture. Instantiating a
|
| 28 |
+
configuration with the defaults will yield a similar configuration to that of the SeamlessM4Tv2
|
| 29 |
+
[""](https://huggingface.co/"") architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
vocab_size (`int`, *optional*, defaults to 256102):
|
| 37 |
+
Vocabulary size of the text modality of the SeamlessM4Tv2 model. Defines the number of different tokens
|
| 38 |
+
that can be represented by the `inputs_ids` passed when calling [`~SeamlessM4Tv2Model`],
|
| 39 |
+
[`~SeamlessM4Tv2ForTextToSpeech`] or [`~SeamlessM4Tv2ForTextToText`].
|
| 40 |
+
t2u_vocab_size (`int`, *optional*, defaults to 10082):
|
| 41 |
+
Unit vocabulary size of the SeamlessM4Tv2 model. Defines the number of different "unit tokens" that can be
|
| 42 |
+
represented by the `inputs_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4Tv2Model`],
|
| 43 |
+
[`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
|
| 44 |
+
char_vocab_size (`int`, *optional*, defaults to 10943):
|
| 45 |
+
Character vocabulary size of the SeamlessM4Tv2 model. Defines the number of different character tokens that
|
| 46 |
+
can be represented by the `char_inputs_ids` passed when calling the Text-To-Units sub-model of
|
| 47 |
+
[`~SeamlessM4Tv2Model`], [`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
|
| 48 |
+
|
| 49 |
+
> Parameters shared across sub-models
|
| 50 |
+
|
| 51 |
+
hidden_size (`int`, *optional*, defaults to 1024):
|
| 52 |
+
Dimensionality of the "intermediate" layers in the architecture.
|
| 53 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 54 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 55 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 56 |
+
The epsilon used by the layer normalization layers.
|
| 57 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 58 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
| 59 |
+
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
| 60 |
+
The maximum sequence length that this model text encoder and decoder might ever be used with. Typically set
|
| 61 |
+
this to something large just in case (e.g., 512 or 1024 or 2048).
|
| 62 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
| 63 |
+
Whether the model is used as an encoder/decoder or not.
|
| 64 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.05):
|
| 65 |
+
The LayerDrop probability for the encoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 66 |
+
for more details.
|
| 67 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.05):
|
| 68 |
+
The LayerDrop probability for the decoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 69 |
+
for more details.
|
| 70 |
+
activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
|
| 71 |
+
The non-linear activation function (function or string) in the decoder and feed-forward layers. If string,
|
| 72 |
+
`"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
|
| 73 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
| 74 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler.
|
| 75 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
| 76 |
+
The dropout probability for all attention layers.
|
| 77 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 78 |
+
The dropout probability for all activation layers in the model.
|
| 79 |
+
scale_embedding (`bool`, *optional*, defaults to `True`):
|
| 80 |
+
Scale embeddings by diving by sqrt(d_model).
|
| 81 |
+
|
| 82 |
+
> Text encoder and text decoder specific parameters
|
| 83 |
+
|
| 84 |
+
encoder_layers (`int`, *optional*, defaults to 24):
|
| 85 |
+
Number of hidden layers in the Transformer text encoder.
|
| 86 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 87 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder.
|
| 88 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 89 |
+
Number of attention heads for each attention layer in the Transformer text encoder.
|
| 90 |
+
decoder_layers (`int`, *optional*, defaults to 24):
|
| 91 |
+
Number of hidden layers in the Transformer text decoder.
|
| 92 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 93 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder.
|
| 94 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 95 |
+
Number of attention heads for each attention layer in the Transformer text decoder.
|
| 96 |
+
decoder_start_token_id (`int`, *optional*, defaults to 3):
|
| 97 |
+
If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
|
| 98 |
+
applied in the text decoder.
|
| 99 |
+
max_new_tokens (`int`, *optional*, defaults to 256):
|
| 100 |
+
The maximum numbers of text tokens to generate, ignoring the number of tokens in the prompt.
|
| 101 |
+
pad_token_id (`int`, *optional*, defaults to 0):
|
| 102 |
+
The id of the _padding_ text token. Only applied to the text-decoder model.
|
| 103 |
+
bos_token_id (`int`, *optional*, defaults to 2):
|
| 104 |
+
The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model.
|
| 105 |
+
eos_token_id (`int`, *optional*, defaults to 3):
|
| 106 |
+
The id of the _end-of-stream_ text token. Only applied to the text-decoder model.
|
| 107 |
+
|
| 108 |
+
> Speech encoder specific parameters
|
| 109 |
+
|
| 110 |
+
speech_encoder_layers (`int`, *optional*, defaults to 24):
|
| 111 |
+
Number of hidden layers in the Transformer speech encoder.
|
| 112 |
+
speech_encoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 113 |
+
Number of attention heads for each attention layer in the Transformer speech encoder.
|
| 114 |
+
speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096):
|
| 115 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder.
|
| 116 |
+
speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
|
| 117 |
+
The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`,
|
| 118 |
+
`"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
|
| 119 |
+
speech_encoder_dropout (`float`, *optional*, defaults to 0.0):
|
| 120 |
+
The dropout probability for all layers in the speech encoder.
|
| 121 |
+
add_adapter (`bool`, *optional*, defaults to `True`):
|
| 122 |
+
Add an adapter layer on top of the speech encoder.
|
| 123 |
+
speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1):
|
| 124 |
+
The LayerDrop probability for the speech encoder. See the [LayerDrop paper](see
|
| 125 |
+
https://arxiv.org/abs/1909.11556) for more details.
|
| 126 |
+
feature_projection_input_dim (`int`, *optional*, defaults to 160):
|
| 127 |
+
Input dimension of the input feature projection of the speech encoder, i.e the dimension after processing
|
| 128 |
+
input audios with [`SeamlessM4TFeatureExtractor`].
|
| 129 |
+
adaptor_kernel_size (`int`, *optional*, defaults to 8):
|
| 130 |
+
Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 131 |
+
adaptor_stride (`int`, *optional*, defaults to 8):
|
| 132 |
+
Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 133 |
+
adaptor_dropout (`float`, *optional*, defaults to 0.1):
|
| 134 |
+
The dropout probability for all layers in the speech adapter.
|
| 135 |
+
num_adapter_layers (`int`, *optional*, defaults to 1):
|
| 136 |
+
Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
|
| 137 |
+
True`.
|
| 138 |
+
position_embeddings_type (`str`, *optional*, defaults to `"relative_key"`):
|
| 139 |
+
Can be specified to `relative_key`. If left to `None`, no relative position embedding is applied. Only
|
| 140 |
+
applied to the speech encoder. For more information on `"relative_key"`, please refer to [Self-Attention
|
| 141 |
+
with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
| 142 |
+
conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
|
| 143 |
+
Kernel size of convolutional depthwise 1D layer in Conformer blocks. Only applied to the speech encoder.
|
| 144 |
+
left_max_position_embeddings (`int`, *optional*, defaults to 64):
|
| 145 |
+
The left clipping value for relative positions.
|
| 146 |
+
right_max_position_embeddings (`int`, *optional*, defaults to 8):
|
| 147 |
+
The right clipping value for relative positions.
|
| 148 |
+
speech_encoder_chunk_size (`int`, *optional*, defaults to 20000): The size of each attention chunk.
|
| 149 |
+
speech_encoder_left_chunk_num (`int`, *optional*, defaults to 128):
|
| 150 |
+
Number of chunks on the left up to which lookahead is allowed.
|
| 151 |
+
|
| 152 |
+
> Text-To-Unit (t2u) model specific parameters
|
| 153 |
+
|
| 154 |
+
t2u_bos_token_id (`int`, *optional*, defaults to 0):
|
| 155 |
+
The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
|
| 156 |
+
t2u_pad_token_id (`int`, *optional*, defaults to 1):
|
| 157 |
+
The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model.
|
| 158 |
+
t2u_eos_token_id (`int`, *optional*, defaults to 2):
|
| 159 |
+
The id of the _end-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
|
| 160 |
+
t2u_encoder_layers (`int`, *optional*, defaults to 6):
|
| 161 |
+
Number of hidden layers in the Transformer text-to-unit encoder.
|
| 162 |
+
t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 163 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder.
|
| 164 |
+
t2u_encoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 165 |
+
Number of attention heads for each attention layer in the Transformer text-to-unit encoder.
|
| 166 |
+
t2u_decoder_layers (`int`, *optional*, defaults to 6):
|
| 167 |
+
Number of hidden layers in the Transformer text-to-unit decoder.
|
| 168 |
+
t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192):
|
| 169 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder.
|
| 170 |
+
t2u_decoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 171 |
+
Number of attention heads for each attention layer in the Transformer text-to-unit decoder.
|
| 172 |
+
t2u_max_position_embeddings (`int`, *optional*, defaults to 4096):
|
| 173 |
+
The maximum sequence length that this model text-to-unit component might ever be used with. Typically set
|
| 174 |
+
this to something large just in case (e.g., 512 or 1024 or 2048).
|
| 175 |
+
t2u_variance_predictor_embed_dim (`int`, *optional*, defaults to 1024):
|
| 176 |
+
The projection dimension of the text-to-unit's duration predictor.
|
| 177 |
+
t2u_variance_predictor_hidden_dim (`int`, *optional*, defaults to 256):
|
| 178 |
+
Internal dimension of the text-to-unit's duration predictor.
|
| 179 |
+
t2u_variance_predictor_kernel_size (`int`, *optional*, defaults to 3):
|
| 180 |
+
Kernel size of the convolutional layers of the text-to-unit's duration predictor.
|
| 181 |
+
t2u_variance_pred_dropout (`float`, *optional*, defaults to 0.5):
|
| 182 |
+
The dropout probability of the text-to-unit's duration predictor.
|
| 183 |
+
|
| 184 |
+
> Hifi-Gan Vocoder specific parameters
|
| 185 |
+
|
| 186 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 187 |
+
The sampling rate at which the output audio will be generated, expressed in hertz (Hz).
|
| 188 |
+
upsample_initial_channel (`int`, *optional*, defaults to 512):
|
| 189 |
+
The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only.
|
| 190 |
+
upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`):
|
| 191 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network.
|
| 192 |
+
The length of *upsample_rates* defines the number of convolutional layers and has to match the length of
|
| 193 |
+
*upsample_kernel_sizes*. Applies to the vocoder only.
|
| 194 |
+
upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`):
|
| 195 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling
|
| 196 |
+
network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match
|
| 197 |
+
the length of *upsample_rates*. Applies to the vocoder only.
|
| 198 |
+
resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
|
| 199 |
+
A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive
|
| 200 |
+
field fusion (MRF) module. Applies to the vocoder only.
|
| 201 |
+
resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
|
| 202 |
+
A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in
|
| 203 |
+
the multi-receptive field fusion (MRF) module. Applies to the vocoder only.
|
| 204 |
+
leaky_relu_slope (`float`, *optional*, defaults to 0.1):
|
| 205 |
+
The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder
|
| 206 |
+
only.
|
| 207 |
+
unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000):
|
| 208 |
+
Vocabulary size of the SeamlessM4Tv2 vocoder. Defines the number of different unit tokens that can be
|
| 209 |
+
represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4Tv2Model`],
|
| 210 |
+
[`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
|
| 211 |
+
unit_embed_dim (`int`, *optional*, defaults to 1280):
|
| 212 |
+
The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only.
|
| 213 |
+
lang_embed_dim (`int`, *optional*, defaults to 256):
|
| 214 |
+
The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only.
|
| 215 |
+
spkr_embed_dim (`int`, *optional*, defaults to 256):
|
| 216 |
+
The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only.
|
| 217 |
+
vocoder_num_langs (`int`, *optional*, defaults to 36):
|
| 218 |
+
Number of langs supported by the vocoder. Might be different from `t2u_num_langs`.
|
| 219 |
+
vocoder_num_spkrs (`int`, *optional*, defaults to 200):
|
| 220 |
+
Number of speakers supported by the vocoder.
|
| 221 |
+
variance_predictor_kernel_size (`int`, *optional*, defaults to 3):
|
| 222 |
+
Kernel size of the duration predictor. Applies to the vocoder only.
|
| 223 |
+
var_pred_dropout (`float`, *optional*, defaults to 0.5):
|
| 224 |
+
The dropout probability of the duration predictor. Applies to the vocoder only.
|
| 225 |
+
vocoder_offset (`int`, *optional*, defaults to 4):
|
| 226 |
+
Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only.
|
| 227 |
+
|
| 228 |
+
```python
|
| 229 |
+
>>> from transformers import SeamlessM4Tv2Model, SeamlessM4Tv2Config
|
| 230 |
+
|
| 231 |
+
>>> # Initializing a SeamlessM4Tv2 "" style configuration
|
| 232 |
+
>>> configuration = SeamlessM4Tv2Config()
|
| 233 |
+
|
| 234 |
+
>>> # Initializing a model from the "" style configuration
|
| 235 |
+
>>> model = SeamlessM4Tv2Model(configuration)
|
| 236 |
+
|
| 237 |
+
>>> # Accessing the model configuration
|
| 238 |
+
>>> configuration = model.config
|
| 239 |
+
```"""
|
| 240 |
+
|
| 241 |
+
model_type = "seamless_m4t_v2"
|
| 242 |
+
|
| 243 |
+
def __init__(
|
| 244 |
+
self,
|
| 245 |
+
vocab_size=256102,
|
| 246 |
+
t2u_vocab_size=10082,
|
| 247 |
+
char_vocab_size=10943,
|
| 248 |
+
# shared config
|
| 249 |
+
hidden_size=1024,
|
| 250 |
+
initializer_range=0.02,
|
| 251 |
+
layer_norm_eps=1e-5,
|
| 252 |
+
use_cache=True,
|
| 253 |
+
max_position_embeddings=4096,
|
| 254 |
+
is_encoder_decoder=True,
|
| 255 |
+
encoder_layerdrop=0.05,
|
| 256 |
+
decoder_layerdrop=0.05,
|
| 257 |
+
activation_function="relu",
|
| 258 |
+
dropout=0.1,
|
| 259 |
+
attention_dropout=0.1,
|
| 260 |
+
activation_dropout=0.0,
|
| 261 |
+
scale_embedding=True,
|
| 262 |
+
# text encoder|decoder
|
| 263 |
+
encoder_layers=24,
|
| 264 |
+
encoder_ffn_dim=8192,
|
| 265 |
+
encoder_attention_heads=16,
|
| 266 |
+
decoder_layers=24,
|
| 267 |
+
decoder_ffn_dim=8192,
|
| 268 |
+
decoder_attention_heads=16,
|
| 269 |
+
decoder_start_token_id=3,
|
| 270 |
+
max_new_tokens=256,
|
| 271 |
+
pad_token_id=0,
|
| 272 |
+
bos_token_id=2,
|
| 273 |
+
eos_token_id=3,
|
| 274 |
+
# speech_encoder
|
| 275 |
+
speech_encoder_layers=24,
|
| 276 |
+
speech_encoder_attention_heads=16,
|
| 277 |
+
speech_encoder_intermediate_size=4096,
|
| 278 |
+
speech_encoder_hidden_act="swish",
|
| 279 |
+
speech_encoder_dropout=0.0,
|
| 280 |
+
add_adapter=True,
|
| 281 |
+
speech_encoder_layerdrop=0.1,
|
| 282 |
+
feature_projection_input_dim=160,
|
| 283 |
+
adaptor_kernel_size=8,
|
| 284 |
+
adaptor_stride=8,
|
| 285 |
+
adaptor_dropout=0.1,
|
| 286 |
+
num_adapter_layers=1,
|
| 287 |
+
position_embeddings_type="relative_key",
|
| 288 |
+
conv_depthwise_kernel_size=31,
|
| 289 |
+
left_max_position_embeddings=64,
|
| 290 |
+
right_max_position_embeddings=8,
|
| 291 |
+
speech_encoder_chunk_size=20000,
|
| 292 |
+
speech_encoder_left_chunk_num=128,
|
| 293 |
+
# t2u config
|
| 294 |
+
t2u_bos_token_id=0,
|
| 295 |
+
t2u_pad_token_id=1,
|
| 296 |
+
t2u_eos_token_id=2,
|
| 297 |
+
t2u_encoder_layers=6,
|
| 298 |
+
t2u_encoder_ffn_dim=8192,
|
| 299 |
+
t2u_encoder_attention_heads=16,
|
| 300 |
+
t2u_decoder_layers=6,
|
| 301 |
+
t2u_decoder_ffn_dim=8192,
|
| 302 |
+
t2u_decoder_attention_heads=16,
|
| 303 |
+
t2u_max_position_embeddings=4096,
|
| 304 |
+
t2u_variance_predictor_embed_dim=1024,
|
| 305 |
+
t2u_variance_predictor_hidden_dim=256,
|
| 306 |
+
t2u_variance_predictor_kernel_size=3,
|
| 307 |
+
t2u_variance_pred_dropout=0.5,
|
| 308 |
+
# hifi-gan vocoder config
|
| 309 |
+
sampling_rate=16000,
|
| 310 |
+
upsample_initial_channel=512,
|
| 311 |
+
upsample_rates=[5, 4, 4, 2, 2],
|
| 312 |
+
upsample_kernel_sizes=[11, 8, 8, 4, 4],
|
| 313 |
+
resblock_kernel_sizes=[3, 7, 11],
|
| 314 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 315 |
+
leaky_relu_slope=0.1,
|
| 316 |
+
# specific to Code Hifi-Gan
|
| 317 |
+
unit_hifi_gan_vocab_size=10000,
|
| 318 |
+
unit_embed_dim=1280,
|
| 319 |
+
lang_embed_dim=256,
|
| 320 |
+
spkr_embed_dim=256,
|
| 321 |
+
vocoder_num_langs=36,
|
| 322 |
+
vocoder_num_spkrs=200,
|
| 323 |
+
variance_predictor_kernel_size=3,
|
| 324 |
+
var_pred_dropout=0.5,
|
| 325 |
+
vocoder_offset=4,
|
| 326 |
+
**kwargs,
|
| 327 |
+
):
|
| 328 |
+
# overall_config
|
| 329 |
+
self.vocab_size = vocab_size
|
| 330 |
+
self.t2u_vocab_size = t2u_vocab_size
|
| 331 |
+
self.char_vocab_size = char_vocab_size
|
| 332 |
+
self.hidden_size = hidden_size
|
| 333 |
+
self.initializer_range = initializer_range
|
| 334 |
+
self.layer_norm_eps = layer_norm_eps
|
| 335 |
+
self.max_position_embeddings = max_position_embeddings
|
| 336 |
+
self.use_cache = use_cache
|
| 337 |
+
self.max_new_tokens = max_new_tokens
|
| 338 |
+
self.encoder_layerdrop = encoder_layerdrop
|
| 339 |
+
self.decoder_layerdrop = decoder_layerdrop
|
| 340 |
+
self.activation_function = activation_function
|
| 341 |
+
self.dropout = dropout
|
| 342 |
+
self.attention_dropout = attention_dropout
|
| 343 |
+
self.activation_dropout = activation_dropout
|
| 344 |
+
self.scale_embedding = scale_embedding
|
| 345 |
+
# for proper config init
|
| 346 |
+
self.num_attention_heads = decoder_attention_heads
|
| 347 |
+
self.num_hidden_layers = decoder_layers
|
| 348 |
+
|
| 349 |
+
# text|unit encoder|decoder
|
| 350 |
+
self.encoder_layers = encoder_layers
|
| 351 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 352 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 353 |
+
self.decoder_layers = decoder_layers
|
| 354 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 355 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 356 |
+
|
| 357 |
+
# speech_encoder
|
| 358 |
+
self.speech_encoder_layers = speech_encoder_layers
|
| 359 |
+
self.speech_encoder_hidden_act = speech_encoder_hidden_act
|
| 360 |
+
self.speech_encoder_dropout = speech_encoder_dropout
|
| 361 |
+
self.speech_encoder_attention_heads = speech_encoder_attention_heads
|
| 362 |
+
self.speech_encoder_layerdrop = speech_encoder_layerdrop
|
| 363 |
+
self.speech_encoder_intermediate_size = speech_encoder_intermediate_size
|
| 364 |
+
self.feature_projection_input_dim = feature_projection_input_dim
|
| 365 |
+
self.adaptor_kernel_size = adaptor_kernel_size
|
| 366 |
+
self.adaptor_stride = adaptor_stride
|
| 367 |
+
self.adaptor_dropout = adaptor_dropout
|
| 368 |
+
self.num_adapter_layers = num_adapter_layers
|
| 369 |
+
self.position_embeddings_type = position_embeddings_type
|
| 370 |
+
self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
|
| 371 |
+
self.add_adapter = add_adapter
|
| 372 |
+
self.left_max_position_embeddings = left_max_position_embeddings
|
| 373 |
+
self.right_max_position_embeddings = right_max_position_embeddings
|
| 374 |
+
self.speech_encoder_chunk_size = speech_encoder_chunk_size
|
| 375 |
+
self.speech_encoder_left_chunk_num = speech_encoder_left_chunk_num
|
| 376 |
+
|
| 377 |
+
# t2u config
|
| 378 |
+
self.t2u_bos_token_id = t2u_bos_token_id
|
| 379 |
+
self.t2u_pad_token_id = t2u_pad_token_id
|
| 380 |
+
self.t2u_eos_token_id = t2u_eos_token_id
|
| 381 |
+
self.t2u_encoder_layers = t2u_encoder_layers
|
| 382 |
+
self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim
|
| 383 |
+
self.t2u_encoder_attention_heads = t2u_encoder_attention_heads
|
| 384 |
+
self.t2u_decoder_layers = t2u_decoder_layers
|
| 385 |
+
self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim
|
| 386 |
+
self.t2u_decoder_attention_heads = t2u_decoder_attention_heads
|
| 387 |
+
self.t2u_max_position_embeddings = t2u_max_position_embeddings
|
| 388 |
+
self.t2u_variance_predictor_embed_dim = t2u_variance_predictor_embed_dim # TODO: add to docstrings
|
| 389 |
+
self.t2u_variance_predictor_hidden_dim = t2u_variance_predictor_hidden_dim # TODO: add to docstrings
|
| 390 |
+
self.t2u_variance_predictor_kernel_size = t2u_variance_predictor_kernel_size # TODO: add to docstrings
|
| 391 |
+
self.t2u_variance_pred_dropout = t2u_variance_pred_dropout # TODO: add to docstrings
|
| 392 |
+
|
| 393 |
+
# hifi-gan vocoder config
|
| 394 |
+
# original parameters specific to Hifi-Gan
|
| 395 |
+
self.sampling_rate = sampling_rate
|
| 396 |
+
self.upsample_initial_channel = upsample_initial_channel
|
| 397 |
+
self.upsample_rates = upsample_rates
|
| 398 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
| 399 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
| 400 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
| 401 |
+
self.leaky_relu_slope = leaky_relu_slope
|
| 402 |
+
|
| 403 |
+
# specific to Code Hifi-Gan
|
| 404 |
+
self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size
|
| 405 |
+
self.unit_embed_dim = unit_embed_dim
|
| 406 |
+
self.lang_embed_dim = lang_embed_dim
|
| 407 |
+
self.spkr_embed_dim = spkr_embed_dim
|
| 408 |
+
self.vocoder_num_langs = vocoder_num_langs
|
| 409 |
+
self.vocoder_num_spkrs = vocoder_num_spkrs
|
| 410 |
+
self.variance_predictor_kernel_size = variance_predictor_kernel_size
|
| 411 |
+
self.var_pred_dropout = var_pred_dropout
|
| 412 |
+
self.vocoder_offset = vocoder_offset
|
| 413 |
+
|
| 414 |
+
super().__init__(
|
| 415 |
+
pad_token_id=pad_token_id,
|
| 416 |
+
bos_token_id=bos_token_id,
|
| 417 |
+
eos_token_id=eos_token_id,
|
| 418 |
+
decoder_start_token_id=decoder_start_token_id,
|
| 419 |
+
is_encoder_decoder=is_encoder_decoder,
|
| 420 |
+
max_position_embeddings=max_position_embeddings,
|
| 421 |
+
**kwargs,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
__all__ = ["SeamlessM4Tv2Config"]
|
test.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 2 |
+
swift infer \
|
| 3 |
+
--adapters /root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324\
|
| 4 |
+
--stream true \
|
| 5 |
+
--temperature 0 \
|
| 6 |
+
--max_new_tokens 2048
|
test_qwenOmni.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig, get_template
|
| 10 |
+
from transformers import HfArgumentParser
|
| 11 |
+
from transformers import Qwen2_5OmniProcessor
|
| 12 |
+
from dataset.dataset2 import AudioDataset
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class TestArguments:
|
| 16 |
+
"""
|
| 17 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 18 |
+
"""
|
| 19 |
+
MODEL_PATH = "/root/autodl-tmp/Qwen2.5-Omni-7B" # 基础模型路径
|
| 20 |
+
LORA_PATH = "/root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324" # LoRA 模型路径
|
| 21 |
+
DATA_FILE = "/root/ms-swift/silence_overlaps/test" # 测试数据文件
|
| 22 |
+
OUTPUT_DIR = "omini_inference_7B_overlap5sVal_SFT_allset.json" # 推理结果输出目录
|
| 23 |
+
|
| 24 |
+
model_path: Optional[str] = field(default=MODEL_PATH, metadata={"help": "base model dir"})
|
| 25 |
+
lora_path: Optional[str] = field(default=LORA_PATH, metadata={"help": "lora model dir"})
|
| 26 |
+
out_file: Optional[str] = field(default=OUTPUT_DIR, metadata={"help": "output file for test"})
|
| 27 |
+
data_dir: Optional[str] = field(default=DATA_FILE, metadata={"help": "test data directory"})
|
| 28 |
+
DEVICE: Optional[str] = field(default="cuda:0", metadata={"help": "device to use"})
|
| 29 |
+
force: Optional[bool] = field(default=False, metadata={"help": "force test"})
|
| 30 |
+
batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for processing"})
|
| 31 |
+
|
| 32 |
+
def __post_init__(self):
|
| 33 |
+
if self.model_path is None:
|
| 34 |
+
raise ValueError("config path should not none")
|
| 35 |
+
if self.data_dir is None:
|
| 36 |
+
raise ValueError("data directory should not be none")
|
| 37 |
+
|
| 38 |
+
def get_prompt_templates():
|
| 39 |
+
prompt_template = (
|
| 40 |
+
"You are an expert at analyzing overlapping speech in conversations. Please analyze the speech dialogue and focus specifically on:\n"
|
| 41 |
+
"Please summarize if any overlaps exceed the 3-second threshold."
|
| 42 |
+
)
|
| 43 |
+
return prompt_template
|
| 44 |
+
|
| 45 |
+
def extract_overall_score(output_str):
|
| 46 |
+
"""从输出中提取<overall score>X</overall score>"""
|
| 47 |
+
score_pattern = r"<overall score>(\d+)</overall score>"
|
| 48 |
+
match = re.search(score_pattern, output_str)
|
| 49 |
+
if match:
|
| 50 |
+
try:
|
| 51 |
+
return int(match.group(1))
|
| 52 |
+
except ValueError:
|
| 53 |
+
pass
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
def main():
|
| 57 |
+
parser = HfArgumentParser(TestArguments)
|
| 58 |
+
data_args = parser.parse_args_into_dataclasses()[0]
|
| 59 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 60 |
+
logging.info("Starting inference with arguments: %s", data_args)
|
| 61 |
+
|
| 62 |
+
if not data_args.force and os.path.exists(data_args.out_file) and os.path.getsize(data_args.out_file) > 0:
|
| 63 |
+
logging.info(f"The {data_args.out_file} exists. Do not regenerate it.")
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
# 设置GPU设备
|
| 67 |
+
device = torch.device(data_args.DEVICE if torch.cuda.is_available() else "cpu")
|
| 68 |
+
logging.info(f"Using device: {device}")
|
| 69 |
+
|
| 70 |
+
# 初始化音频处理器
|
| 71 |
+
logging.info("Loading processor...")
|
| 72 |
+
processor = Qwen2_5OmniProcessor.from_pretrained(data_args.model_path)
|
| 73 |
+
|
| 74 |
+
# 初始化推理引擎
|
| 75 |
+
logging.info("Initializing inference engine...")
|
| 76 |
+
engine = PtEngine(data_args.model_path, adapters=[data_args.lora_path])
|
| 77 |
+
engine.processor = processor
|
| 78 |
+
template = get_template(engine.model.model_meta.template, processor, default_system="You are a helpful assistant.")
|
| 79 |
+
engine.default_template = template
|
| 80 |
+
template.processor = processor
|
| 81 |
+
# 初始化数据集
|
| 82 |
+
logging.info("Initializing dataset from %s", data_args.data_dir)
|
| 83 |
+
dataset = AudioDataset(data_args.data_dir)
|
| 84 |
+
logging.info(f"Dataset loaded successfully with {len(dataset)} samples")
|
| 85 |
+
|
| 86 |
+
# 获取提示模板
|
| 87 |
+
prompt_template = get_prompt_templates()
|
| 88 |
+
|
| 89 |
+
all_outputs = []
|
| 90 |
+
batch_size = data_args.batch_size
|
| 91 |
+
total_batches = (len(dataset) + batch_size - 1) // batch_size
|
| 92 |
+
logging.info(f"Starting batch processing with batch size {batch_size}, total batches: {total_batches}")
|
| 93 |
+
|
| 94 |
+
for i in range(0, len(dataset), batch_size):
|
| 95 |
+
current_batch = i // batch_size + 1
|
| 96 |
+
logging.info(f"Processing batch {current_batch}/{total_batches}")
|
| 97 |
+
|
| 98 |
+
batch_data = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
|
| 99 |
+
|
| 100 |
+
# Process each sample
|
| 101 |
+
batch_outputs = []
|
| 102 |
+
for bd in batch_data:
|
| 103 |
+
# 构建推理请求
|
| 104 |
+
infer_request = InferRequest(
|
| 105 |
+
messages=bd["prompt"],
|
| 106 |
+
audios=[bd["audio"]]
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# 设置推理配置
|
| 110 |
+
request_config = RequestConfig(
|
| 111 |
+
max_tokens=512,
|
| 112 |
+
temperature=0,
|
| 113 |
+
do_sample=False,
|
| 114 |
+
num_beams=1
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# 执行推理
|
| 118 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 119 |
+
response = resp_list[0].choices[0].message.content
|
| 120 |
+
batch_outputs.append(response)
|
| 121 |
+
|
| 122 |
+
all_outputs.extend(batch_outputs)
|
| 123 |
+
logging.info(f"Completed batch {current_batch}/{total_batches}")
|
| 124 |
+
|
| 125 |
+
final_output = []
|
| 126 |
+
correct_count = 0
|
| 127 |
+
total_count = 0
|
| 128 |
+
true_positive = 0
|
| 129 |
+
false_positive = 0
|
| 130 |
+
false_negative = 0
|
| 131 |
+
|
| 132 |
+
for input_example, model_output in zip(dataset, all_outputs):
|
| 133 |
+
pred_score = extract_overall_score(model_output)
|
| 134 |
+
gt_score = input_example.get("solution", None)
|
| 135 |
+
|
| 136 |
+
result = {
|
| 137 |
+
"id": input_example.get("id", None),
|
| 138 |
+
"gt_score": gt_score,
|
| 139 |
+
"model_output": model_output,
|
| 140 |
+
"predicted_score": pred_score
|
| 141 |
+
}
|
| 142 |
+
final_output.append(result)
|
| 143 |
+
|
| 144 |
+
if pred_score is not None and gt_score is not None:
|
| 145 |
+
total_count += 1
|
| 146 |
+
if pred_score == gt_score:
|
| 147 |
+
correct_count += 1
|
| 148 |
+
true_positive += 1
|
| 149 |
+
else:
|
| 150 |
+
false_positive += 1
|
| 151 |
+
false_negative += 1
|
| 152 |
+
|
| 153 |
+
accuracy = correct_count / total_count if total_count > 0 else 0
|
| 154 |
+
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
|
| 155 |
+
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
|
| 156 |
+
|
| 157 |
+
# 添加准确率指标到最终输出
|
| 158 |
+
metrics = {
|
| 159 |
+
"accuracy": accuracy,
|
| 160 |
+
"precision": precision,
|
| 161 |
+
"recall": recall,
|
| 162 |
+
"correct_count": correct_count,
|
| 163 |
+
"total_count": total_count
|
| 164 |
+
}
|
| 165 |
+
final_output.append({"metrics": metrics})
|
| 166 |
+
|
| 167 |
+
logging.info("Saving results to %s", data_args.out_file)
|
| 168 |
+
with open(data_args.out_file, "w") as f:
|
| 169 |
+
json.dump(final_output, f, indent=2)
|
| 170 |
+
|
| 171 |
+
logging.info(f"Results saved successfully.")
|
| 172 |
+
logging.info(f"准确率: {accuracy:.4f} ({correct_count}/{total_count})")
|
| 173 |
+
logging.info(f"召回率: {recall:.4f}")
|
| 174 |
+
logging.info(f"精确率: {precision:.4f}")
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
main()
|
train.sh
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 3 |
+
--model /root/autodl-tmp/output_7B_FULL_4JOB/v2-20250621-170947/checkpoint-608 \
|
| 4 |
+
--dataset ./dataset_newCOTSFT1_filtered_90.0s_resampled_16000.jsonl \
|
| 5 |
+
--model_type qwen2_5_omni\
|
| 6 |
+
--train_type full \
|
| 7 |
+
--output_dir /root/autodl-tmp/output_7B_FULL_cotSFT \
|
| 8 |
+
--torch_dtype bfloat16 \
|
| 9 |
+
--learning_rate 1e-4 \
|
| 10 |
+
--num_train_epochs 2 \
|
| 11 |
+
--freeze_vit false \
|
| 12 |
+
--freeze_aligner false \
|
| 13 |
+
--per_device_train_batch_size 1 \
|
| 14 |
+
--per_device_eval_batch_size 1 \
|
| 15 |
+
# ...
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
#CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 19 |
+
# --model /root/autodl-tmp/Qwen2.5-Omni-7B \
|
| 20 |
+
# --dataset /root/ms-swift/dataset_cotSFT.json \
|
| 21 |
+
# --model_type qwen2_5_omni\
|
| 22 |
+
# --train_type lora \
|
| 23 |
+
# --output_dir /root/autodl-tmp/output_7B_Lora_cotSFT \
|
| 24 |
+
# --torch_dtype bfloat16 \
|
| 25 |
+
# --learning_rate 1e-4 \
|
| 26 |
+
# --lora_rank 8 \
|
| 27 |
+
# --lora_alpha 32 \
|
| 28 |
+
# --target_modules all-linear \
|
| 29 |
+
# --num_train_epochs 3 \
|
| 30 |
+
# --freeze_vit false \
|
| 31 |
+
# --freeze_aligner false \
|
| 32 |
+
# --per_device_train_batch_size 3 \
|
| 33 |
+
# --per_device_eval_batch_size 1 \
|
| 34 |
+
# ...
|
| 35 |
+
# # 8*A100
|
| 36 |
+
# NPROC_PER_NODE=8 \
|
| 37 |
+
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 38 |
+
# swift pt \
|
| 39 |
+
# --model Qwen/Qwen2.5-7B \
|
| 40 |
+
# --dataset swift/chinese-c4 \
|
| 41 |
+
# --streaming true \
|
| 42 |
+
# --train_type full \
|
| 43 |
+
# --deepspeed zero2 \
|
| 44 |
+
# --output_dir output \
|
| 45 |
+
# --max_steps 10000 \
|
| 46 |
+
# ...
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# --lora_rank 8 \
|
| 51 |
+
# --lora_alpha 32 \
|
| 52 |
+
# --target_modules all-linear \
|
| 53 |
+
# --gradient_accumulation_steps 16 \
|
| 54 |
+
# --eval_steps 50 \
|
| 55 |
+
# --save_steps 50 \
|
| 56 |
+
# --save_total_limit 2 \
|
| 57 |
+
# --logging_steps 5 \
|
| 58 |
+
# --max_length 2048 \
|
| 59 |
+
# --output_dir output \
|
| 60 |
+
# --system 'You are a helpful assistant.' \
|
| 61 |
+
# --warmup_ratio 0.05 \
|
| 62 |
+
# --dataloader_num_workers 4 \
|
| 63 |
+
# --model_author swift \
|
| 64 |
+
# --model_name swift-robot
|