Upload folder using huggingface_hub
Browse files- .gitattributes +35 -35
- README.md +65 -0
- config.json +75 -0
- configuration_dualvitok.py +153 -0
- configuration_movqgan.py +90 -0
- configuration_qwen2vit.py +249 -0
- image_processing_dualvitok.py +48 -0
- image_processing_movqgan.py +428 -0
- image_processing_qwen2vit.py +476 -0
- image_utils.py +812 -0
- modeling_dualvitok.py +653 -0
- modeling_movqgan.py +828 -0
- modeling_qwen2vit.py +842 -0
- modeling_rope_utils.py +561 -0
- preprocessor_config.json +29 -0
- processing_qwen2vit.py +186 -0
- pytorch_model.bin +3 -0
- sdxl_decoder_pipe.py +901 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,35 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DualViTok
|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
|
| 5 |
+
<img src="https://illume-unified-mllm.github.io/static/images/logo.png" width="100em"></img>
|
| 6 |
+
|
| 7 |
+
📄 [Paper](https://arxiv.org/abs/2504.01934) |
|
| 8 |
+
🌐 [Project-Page](https://illume-unified-mllm.github.io/) |
|
| 9 |
+
📦 [Github](https://github.com/illume-unified-mllm/ILLUME_plus) |
|
| 10 |
+
|
| 11 |
+
</div>
|
| 12 |
+
|
| 13 |
+
## Introduction
|
| 14 |
+
|
| 15 |
+
**DualViTok**, Dual Vision Tokenizer, is a dual-branch vision tokenizer designed to capture both deep semantics and fine-grained textures. It is proposed in [ILLUME+](https://arxiv.org/abs/2504.01934). The semantic branch utilizes a pre-trained text-aligned vision encoder for semantic feature extraction, supervised by feature reconstruction loss. In parallel, the pixel branch integrates quantized features from both the semantic encoder and a CNN-based pixel encoder to enhance pixel-level reconstruction. To improve robustness against incorrect token predictions in autoregressive generation, we introduce noise injection during training by randomly perturbing visual tokens. Despite its simplicity, DualViTok is specifically designed for unified models, ensuring both semantic and texture preservation while maintaining robust token decoding.
|
| 16 |
+
|
| 17 |
+
<div align="center">
|
| 18 |
+
<img src="https://illume-unified-mllm.github.io/static/images/tokenizer_framework.png" width="80%"></img>
|
| 19 |
+
</div>
|
| 20 |
+
|
| 21 |
+
## Quickstart for Autoencoding
|
| 22 |
+
```python
|
| 23 |
+
from PIL import Image
|
| 24 |
+
import torch
|
| 25 |
+
from transformers import AutoModel, AutoImageProcessor
|
| 26 |
+
|
| 27 |
+
MODEL_HUB = "ILLUME-MLLM/dualvitok/"
|
| 28 |
+
|
| 29 |
+
model = AutoModel.from_pretrained(MODEL_HUB, trust_remote_code=True).eval().cuda()
|
| 30 |
+
processor = AutoImageProcessor.from_pretrained(MODEL_HUB, trust_remote_code=True)
|
| 31 |
+
|
| 32 |
+
# load the diffusion decoder.
|
| 33 |
+
# diffusion_decoder = model.build_sdxl_decoder('ILLUME-MLLM/dualvitok-sdxl-decoder')
|
| 34 |
+
|
| 35 |
+
# TODO: you need to modify the path here
|
| 36 |
+
IMAGE_PATH = "YOUR_IMAGE_PATH"
|
| 37 |
+
|
| 38 |
+
image = Image.open(IMAGE_PATH)
|
| 39 |
+
|
| 40 |
+
image = processor(image, return_tensors="pt")["pixel_values"]
|
| 41 |
+
image = image.unsqueeze(0).cuda()
|
| 42 |
+
|
| 43 |
+
with torch.no_grad():
|
| 44 |
+
(quant_semantic, diff_semantic, indices_semantic, _), \
|
| 45 |
+
(quant_pixel, diff_pixel, indices_pixel) = model.encode(image)
|
| 46 |
+
|
| 47 |
+
recon = model.decode(quant_semantic, quant_pixel)
|
| 48 |
+
|
| 49 |
+
# decode from the codes.
|
| 50 |
+
# recon = model.decode_code(indices_semantic, indices_pixel)
|
| 51 |
+
|
| 52 |
+
print(recon.shape)
|
| 53 |
+
recon_image = processor.postprocess(recon)["pixel_values"][0]
|
| 54 |
+
recon_image.save("recon_image.png")
|
| 55 |
+
|
| 56 |
+
# diffusion decoder only support 11 resolution. Check here `diffusion_decoder.resolution_group`.
|
| 57 |
+
# diffusion_recon = diffusion_decoder(# use vq_indices or vq_embeds
|
| 58 |
+
# vq_indices=(indices_semantic, indices_pixel),
|
| 59 |
+
# vq_embeds=(quant_semantic, quant_pixel),
|
| 60 |
+
# height = height * 2,
|
| 61 |
+
# width = width * 2,
|
| 62 |
+
# num_inference_steps = 50,
|
| 63 |
+
# guidance_scale = 1.5,)
|
| 64 |
+
# diffusion_recon.images[0].save("diffusion_recon_image.png")
|
| 65 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoConfig": "configuration_dualvitok.DualViTokConfig",
|
| 4 |
+
"AutoModel": "modeling_dualvitok.DualViTok"
|
| 5 |
+
},
|
| 6 |
+
"architectures": [
|
| 7 |
+
"DualViTok"
|
| 8 |
+
],
|
| 9 |
+
"semantic_encoder": {
|
| 10 |
+
"pretrained_semantic_encoder": "Emova-ollm/qwen2vit600m",
|
| 11 |
+
"z_channels": 32,
|
| 12 |
+
"num_blocks": 4,
|
| 13 |
+
"out_layer": "linear",
|
| 14 |
+
"embed_dim": 1280,
|
| 15 |
+
"target_mlp": "norm"
|
| 16 |
+
},
|
| 17 |
+
"semantic_decoder": {
|
| 18 |
+
"z_channels": 32,
|
| 19 |
+
"num_blocks": 4,
|
| 20 |
+
"embed_dim": 1280,
|
| 21 |
+
"out_layer": "linear_norm",
|
| 22 |
+
"out_channels": 3584
|
| 23 |
+
},
|
| 24 |
+
"semantic_quantizer_type": "simvq",
|
| 25 |
+
"pixel_quantizer_type": "simvq",
|
| 26 |
+
"semantic_quantizer_codebook_size": 32768,
|
| 27 |
+
"pixel_quantizer_codebook_size": 98304,
|
| 28 |
+
"attn_implementation": "eager",
|
| 29 |
+
"pixel_encoder": {
|
| 30 |
+
"codebook_size": 98304,
|
| 31 |
+
"embed_dim": 32,
|
| 32 |
+
"z_channels": 32,
|
| 33 |
+
"double_z": false,
|
| 34 |
+
"in_channels": 3,
|
| 35 |
+
"out_channels": 3,
|
| 36 |
+
"ch": 128,
|
| 37 |
+
"ch_mult": [
|
| 38 |
+
1,
|
| 39 |
+
1,
|
| 40 |
+
2,
|
| 41 |
+
2,
|
| 42 |
+
4
|
| 43 |
+
],
|
| 44 |
+
"num_res_blocks": 2,
|
| 45 |
+
"attn_resolutions": [
|
| 46 |
+
4
|
| 47 |
+
],
|
| 48 |
+
"dropout": 0.0,
|
| 49 |
+
"use_dc_up_down_blocks": true
|
| 50 |
+
},
|
| 51 |
+
"pixel_decoder": {
|
| 52 |
+
"codebook_size": 98304,
|
| 53 |
+
"embed_dim": 64,
|
| 54 |
+
"z_channels": 64,
|
| 55 |
+
"double_z": false,
|
| 56 |
+
"in_channels": 3,
|
| 57 |
+
"out_channels": 3,
|
| 58 |
+
"ch": 384,
|
| 59 |
+
"ch_mult": [
|
| 60 |
+
1,
|
| 61 |
+
1,
|
| 62 |
+
2,
|
| 63 |
+
2,
|
| 64 |
+
4
|
| 65 |
+
],
|
| 66 |
+
"num_res_blocks": 2,
|
| 67 |
+
"attn_resolutions": [
|
| 68 |
+
4
|
| 69 |
+
],
|
| 70 |
+
"dropout": 0.0,
|
| 71 |
+
"use_dc_up_down_blocks": true
|
| 72 |
+
},
|
| 73 |
+
"torch_dtype": "float16",
|
| 74 |
+
"transformers_version": "4.44.2"
|
| 75 |
+
}
|
configuration_dualvitok.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
from transformers.utils import logging
|
| 6 |
+
|
| 7 |
+
from .configuration_movqgan import MoVQConfig
|
| 8 |
+
from .modeling_rope_utils import rope_config_validation
|
| 9 |
+
from .configuration_qwen2vit import Qwen2VLVisionConfig
|
| 10 |
+
|
| 11 |
+
logger = logging.get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SemanticEncoderConfig(PretrainedConfig):
|
| 15 |
+
model_type = "DualViTokSemanticEncoder"
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
pretrained_semantic_encoder='Emova-ollm/qwen2vit600m',
|
| 20 |
+
z_channels=32,
|
| 21 |
+
num_blocks=4,
|
| 22 |
+
embed_dim=1280,
|
| 23 |
+
out_layer='linear',
|
| 24 |
+
target_mlp='norm',
|
| 25 |
+
**kwargs
|
| 26 |
+
):
|
| 27 |
+
super().__init__(**kwargs)
|
| 28 |
+
self.pretrained_semantic_encoder = pretrained_semantic_encoder
|
| 29 |
+
self.z_channels = z_channels
|
| 30 |
+
self.num_blocks = num_blocks
|
| 31 |
+
self.out_layer = out_layer
|
| 32 |
+
self.embed_dim = embed_dim
|
| 33 |
+
self.target_mlp = target_mlp
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SemanticDecoderConfig(PretrainedConfig):
|
| 37 |
+
model_type = "DualViTokSemanticDecoder"
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
z_channels=32,
|
| 42 |
+
num_blocks=4,
|
| 43 |
+
embed_dim=1280,
|
| 44 |
+
out_layer='linear_norm',
|
| 45 |
+
out_channels=3584,
|
| 46 |
+
**kwargs
|
| 47 |
+
):
|
| 48 |
+
super().__init__(**kwargs)
|
| 49 |
+
self.z_channels = z_channels
|
| 50 |
+
self.num_blocks = num_blocks
|
| 51 |
+
self.embed_dim = embed_dim
|
| 52 |
+
self.out_layer = out_layer
|
| 53 |
+
self.out_channels = out_channels
|
| 54 |
+
|
| 55 |
+
class DualViTokConfig(PretrainedConfig):
|
| 56 |
+
r"""
|
| 57 |
+
This is the configuration class to store the configuration of a [`DualViTok`]. It is used to instantiate an video movq
|
| 58 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 59 |
+
defaults will yield a configuration to the VQ model presented in paper.
|
| 60 |
+
|
| 61 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 62 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
codebook_size (`int`, *optional*, defaults to 32768):
|
| 67 |
+
Codebook size of the VQ model.
|
| 68 |
+
embed_dim (`int`, *optional*, defaults to 4):
|
| 69 |
+
Dimension of the quantized vector in codebook.
|
| 70 |
+
z_channels (`int`, *optional*, defaults to 4):
|
| 71 |
+
Dimension of the output channel of encoder and the input channel of decoder
|
| 72 |
+
double_z (`bool`, *optional*, defaults to False):
|
| 73 |
+
Whether double the output dim of the encoder.
|
| 74 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 75 |
+
Input channel of encoder.
|
| 76 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 77 |
+
Output channel of decoder.
|
| 78 |
+
ch (`int`, *optional*, defaults to 256):
|
| 79 |
+
Basic channel number of the intermediate blocks.
|
| 80 |
+
ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
|
| 81 |
+
Channel scaling factor of the intermediate blocks.
|
| 82 |
+
num_res_blocks (`int`, *optional*, defaults to 2):
|
| 83 |
+
Residual block number in each stage.
|
| 84 |
+
attn_resolutions (`List[int]`, *optional*, defaults to 3):
|
| 85 |
+
Stage indices to apply attention.
|
| 86 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 87 |
+
Dropout probability.
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
>>> from transformers import DualViTok, DualViTokConfig
|
| 91 |
+
|
| 92 |
+
>>> # Initializing a video VQ model of configuration
|
| 93 |
+
>>> configuration = DualViTokConfig()
|
| 94 |
+
|
| 95 |
+
>>> # Initializing a model from the VQ model style configuration
|
| 96 |
+
>>> model = DualViTok(configuration)
|
| 97 |
+
|
| 98 |
+
>>> # Accessing the model configuration
|
| 99 |
+
>>> configuration = model.config
|
| 100 |
+
```"""
|
| 101 |
+
|
| 102 |
+
model_type = "DualViTok"
|
| 103 |
+
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
semantic_encoder=None,
|
| 107 |
+
semantic_decoder=None,
|
| 108 |
+
pixel_encoder=None,
|
| 109 |
+
pixel_decoder=None,
|
| 110 |
+
semantic_quantizer_type='simvq',
|
| 111 |
+
pixel_quantizer_type='simvq',
|
| 112 |
+
semantic_quantizer_codebook_size=32768,
|
| 113 |
+
pixel_quantizer_codebook_size=98304,
|
| 114 |
+
attn_implementation='sdpa',
|
| 115 |
+
**kwargs,
|
| 116 |
+
):
|
| 117 |
+
super().__init__(**kwargs)
|
| 118 |
+
if semantic_encoder is None:
|
| 119 |
+
self.semantic_encoder = SemanticEncoderConfig()
|
| 120 |
+
else:
|
| 121 |
+
self.semantic_encoder = SemanticEncoderConfig(**semantic_encoder)
|
| 122 |
+
if semantic_decoder is None:
|
| 123 |
+
self.semantic_decoder = SemanticEncoderConfig()
|
| 124 |
+
else:
|
| 125 |
+
self.semantic_decoder = SemanticEncoderConfig(**semantic_decoder)
|
| 126 |
+
|
| 127 |
+
self.semantic_quantizer_type = semantic_quantizer_type
|
| 128 |
+
self.pixel_quantizer_type = pixel_quantizer_type
|
| 129 |
+
self.semantic_quantizer_codebook_size = semantic_quantizer_codebook_size
|
| 130 |
+
self.pixel_quantizer_codebook_size = pixel_quantizer_codebook_size
|
| 131 |
+
|
| 132 |
+
if pixel_encoder is None:
|
| 133 |
+
self.pixel_encoder = MoVQConfig()
|
| 134 |
+
else:
|
| 135 |
+
self.pixel_encoder = MoVQConfig(**pixel_encoder)
|
| 136 |
+
|
| 137 |
+
self.pixel_decoder = self.pixel_encoder if pixel_decoder is None else MoVQConfig(**pixel_decoder)
|
| 138 |
+
|
| 139 |
+
self.attn_implementation = attn_implementation
|
| 140 |
+
|
| 141 |
+
@classmethod
|
| 142 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], attn_implementation='sdpa', **kwargs) -> "PretrainedConfig":
|
| 143 |
+
cls._set_token_in_kwargs(kwargs)
|
| 144 |
+
|
| 145 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 146 |
+
|
| 147 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 148 |
+
logger.warning(
|
| 149 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 150 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return cls.from_dict(config_dict, attn_implementation=attn_implementation, **kwargs)
|
configuration_movqgan.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" MoVQ model configuration """
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 6 |
+
from transformers.utils import logging
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
logger = logging.get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MoVQConfig(PretrainedConfig):
|
| 13 |
+
r"""
|
| 14 |
+
This is the configuration class to store the configuration of a [`MoVQ`]. It is used to instantiate an video movq
|
| 15 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 16 |
+
defaults will yield a configuration to the VQ model presented in paper.
|
| 17 |
+
|
| 18 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 19 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
codebook_size (`int`, *optional*, defaults to 32768):
|
| 24 |
+
Codebook size of the VQ model.
|
| 25 |
+
embed_dim (`int`, *optional*, defaults to 4):
|
| 26 |
+
Dimension of the quantized vector in codebook.
|
| 27 |
+
z_channels (`int`, *optional*, defaults to 4):
|
| 28 |
+
Dimension of the output channel of encoder and the input channel of decoder
|
| 29 |
+
double_z (`bool`, *optional*, defaults to False):
|
| 30 |
+
Whether double the output dim of the encoder.
|
| 31 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 32 |
+
Input channel of encoder.
|
| 33 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 34 |
+
Output channel of decoder.
|
| 35 |
+
ch (`int`, *optional*, defaults to 256):
|
| 36 |
+
Basic channel number of the intermediate blocks.
|
| 37 |
+
ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
|
| 38 |
+
Channel scaling factor of the intermediate blocks.
|
| 39 |
+
num_res_blocks (`int`, *optional*, defaults to 2):
|
| 40 |
+
Residual block number in each stage.
|
| 41 |
+
attn_resolutions (`List[int]`, *optional*, defaults to 3):
|
| 42 |
+
Stage indices to apply attention.
|
| 43 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 44 |
+
Dropout probability.
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
>>> from transformers import MoVQ, MoVQConfig
|
| 48 |
+
|
| 49 |
+
>>> # Initializing a video VQ model of configuration
|
| 50 |
+
>>> configuration = MoVQConfig()
|
| 51 |
+
|
| 52 |
+
>>> # Initializing a model from the VQ model style configuration
|
| 53 |
+
>>> model = MoVQModel(configuration)
|
| 54 |
+
|
| 55 |
+
>>> # Accessing the model configuration
|
| 56 |
+
>>> configuration = model.config
|
| 57 |
+
```"""
|
| 58 |
+
|
| 59 |
+
model_type = "MoVQ"
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
codebook_size: int = 32768,
|
| 64 |
+
embed_dim: int = 4,
|
| 65 |
+
z_channels: int = 4,
|
| 66 |
+
double_z: bool = False,
|
| 67 |
+
in_channels: int = 3,
|
| 68 |
+
out_channels: int = 3,
|
| 69 |
+
ch: int = 256,
|
| 70 |
+
ch_mult: List[int] = [1, 2, 2, 4],
|
| 71 |
+
num_res_blocks: int = 2,
|
| 72 |
+
attn_resolutions: List[int] = [3],
|
| 73 |
+
dropout: float = 0.0,
|
| 74 |
+
use_dc_up_down_blocks=False,
|
| 75 |
+
**kwargs,
|
| 76 |
+
):
|
| 77 |
+
super().__init__(**kwargs)
|
| 78 |
+
|
| 79 |
+
self.codebook_size = codebook_size
|
| 80 |
+
self.embed_dim = embed_dim
|
| 81 |
+
self.z_channels = z_channels
|
| 82 |
+
self.double_z = double_z
|
| 83 |
+
self.in_channels = in_channels
|
| 84 |
+
self.out_channels = out_channels
|
| 85 |
+
self.ch = ch
|
| 86 |
+
self.ch_mult = ch_mult
|
| 87 |
+
self.num_res_blocks = num_res_blocks
|
| 88 |
+
self.attn_resolutions = attn_resolutions
|
| 89 |
+
self.dropout = dropout
|
| 90 |
+
self.use_dc_up_down_blocks = use_dc_up_down_blocks
|
configuration_qwen2vit.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Qwen2VL model configuration"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from typing import Union
|
| 19 |
+
|
| 20 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 21 |
+
from transformers.utils import logging
|
| 22 |
+
from .modeling_rope_utils import rope_config_validation
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Qwen2VLVisionConfig(PretrainedConfig):
|
| 28 |
+
model_type = "qwen2_vl"
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
depth=32,
|
| 33 |
+
embed_dim=1280,
|
| 34 |
+
hidden_size=3584,
|
| 35 |
+
hidden_act="quick_gelu",
|
| 36 |
+
mlp_ratio=4,
|
| 37 |
+
num_heads=16,
|
| 38 |
+
in_channels=3,
|
| 39 |
+
patch_size=14,
|
| 40 |
+
spatial_merge_size=2,
|
| 41 |
+
temporal_patch_size=2,
|
| 42 |
+
attn_implementation='eager',
|
| 43 |
+
init_weights=False,
|
| 44 |
+
**kwargs,
|
| 45 |
+
):
|
| 46 |
+
super().__init__(**kwargs)
|
| 47 |
+
|
| 48 |
+
self.depth = depth
|
| 49 |
+
self.embed_dim = embed_dim
|
| 50 |
+
self.hidden_size = hidden_size
|
| 51 |
+
self.hidden_act = hidden_act
|
| 52 |
+
self.mlp_ratio = mlp_ratio
|
| 53 |
+
self.num_heads = num_heads
|
| 54 |
+
self.in_channels = in_channels
|
| 55 |
+
self.patch_size = patch_size
|
| 56 |
+
self.spatial_merge_size = spatial_merge_size
|
| 57 |
+
self.temporal_patch_size = temporal_patch_size
|
| 58 |
+
self.attn_implementation = attn_implementation if attn_implementation else 'eager'
|
| 59 |
+
|
| 60 |
+
self.init_weights = init_weights
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
| 64 |
+
cls._set_token_in_kwargs(kwargs)
|
| 65 |
+
|
| 66 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 67 |
+
|
| 68 |
+
# if config_dict.get("model_type") == "qwen2_vl":
|
| 69 |
+
# config_dict = config_dict["vision_config"]
|
| 70 |
+
|
| 71 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 72 |
+
logger.warning(
|
| 73 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 74 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Qwen2VLConfig(PretrainedConfig):
|
| 81 |
+
r"""
|
| 82 |
+
This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
|
| 83 |
+
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 84 |
+
with the defaults will yield a similar configuration to that of
|
| 85 |
+
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
| 86 |
+
|
| 87 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 88 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
vocab_size (`int`, *optional*, defaults to 152064):
|
| 93 |
+
Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
|
| 94 |
+
`inputs_ids` passed when calling [`Qwen2VLModel`]
|
| 95 |
+
hidden_size (`int`, *optional*, defaults to 8192):
|
| 96 |
+
Dimension of the hidden representations.
|
| 97 |
+
intermediate_size (`int`, *optional*, defaults to 29568):
|
| 98 |
+
Dimension of the MLP representations.
|
| 99 |
+
num_hidden_layers (`int`, *optional*, defaults to 80):
|
| 100 |
+
Number of hidden layers in the Transformer encoder.
|
| 101 |
+
num_attention_heads (`int`, *optional*, defaults to 64):
|
| 102 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 103 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 104 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 105 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 106 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 107 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 108 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 109 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
| 110 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 111 |
+
The non-linear activation function (function or string) in the decoder.
|
| 112 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 113 |
+
The maximum sequence length that this model might ever be used with.
|
| 114 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 115 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 116 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 117 |
+
The epsilon used by the rms normalization layers.
|
| 118 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 119 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 120 |
+
relevant if `config.is_decoder=True`.
|
| 121 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 122 |
+
Whether the model's input and output word embeddings should be tied.
|
| 123 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
| 124 |
+
The base period of the RoPE embeddings.
|
| 125 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 126 |
+
Whether to use sliding window attention.
|
| 127 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 128 |
+
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
| 129 |
+
max_window_layers (`int`, *optional*, defaults to 80):
|
| 130 |
+
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
| 131 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 132 |
+
The dropout ratio for the attention probabilities.
|
| 133 |
+
vision_config (`Dict`, *optional*):
|
| 134 |
+
The config for the visual encoder initialization.
|
| 135 |
+
rope_scaling (`Dict`, *optional*):
|
| 136 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 137 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 138 |
+
accordingly.
|
| 139 |
+
Expected contents:
|
| 140 |
+
`rope_type` (`str`):
|
| 141 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 142 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 143 |
+
`factor` (`float`, *optional*):
|
| 144 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 145 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 146 |
+
original maximum pre-trained length.
|
| 147 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 148 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 149 |
+
pretraining.
|
| 150 |
+
`attention_factor` (`float`, *optional*):
|
| 151 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 152 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 153 |
+
`factor` field to infer the suggested value.
|
| 154 |
+
`beta_fast` (`float`, *optional*):
|
| 155 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 156 |
+
ramp function. If unspecified, it defaults to 32.
|
| 157 |
+
`beta_slow` (`float`, *optional*):
|
| 158 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 159 |
+
ramp function. If unspecified, it defaults to 1.
|
| 160 |
+
`short_factor` (`List[float]`, *optional*):
|
| 161 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 162 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 163 |
+
size divided by the number of attention heads divided by 2
|
| 164 |
+
`long_factor` (`List[float]`, *optional*):
|
| 165 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 166 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 167 |
+
size divided by the number of attention heads divided by 2
|
| 168 |
+
`low_freq_factor` (`float`, *optional*):
|
| 169 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 170 |
+
`high_freq_factor` (`float`, *optional*):
|
| 171 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 172 |
+
|
| 173 |
+
```python
|
| 174 |
+
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
|
| 175 |
+
|
| 176 |
+
>>> # Initializing a Qwen2VL style configuration
|
| 177 |
+
>>> configuration = Qwen2VLConfig()
|
| 178 |
+
|
| 179 |
+
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
| 180 |
+
>>> model = Qwen2VLForConditionalGeneration(configuration)
|
| 181 |
+
|
| 182 |
+
>>> # Accessing the model configuration
|
| 183 |
+
>>> configuration = model.config
|
| 184 |
+
```"""
|
| 185 |
+
|
| 186 |
+
model_type = "qwen2_vl"
|
| 187 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 188 |
+
|
| 189 |
+
def __init__(
|
| 190 |
+
self,
|
| 191 |
+
vocab_size=152064,
|
| 192 |
+
hidden_size=8192,
|
| 193 |
+
intermediate_size=29568,
|
| 194 |
+
num_hidden_layers=80,
|
| 195 |
+
num_attention_heads=64,
|
| 196 |
+
num_key_value_heads=8,
|
| 197 |
+
hidden_act="silu",
|
| 198 |
+
max_position_embeddings=32768,
|
| 199 |
+
initializer_range=0.02,
|
| 200 |
+
rms_norm_eps=1e-05,
|
| 201 |
+
use_cache=True,
|
| 202 |
+
tie_word_embeddings=False,
|
| 203 |
+
rope_theta=1000000.0,
|
| 204 |
+
use_sliding_window=False,
|
| 205 |
+
sliding_window=4096,
|
| 206 |
+
max_window_layers=80,
|
| 207 |
+
attention_dropout=0.0,
|
| 208 |
+
vision_config=None,
|
| 209 |
+
rope_scaling=None,
|
| 210 |
+
**kwargs,
|
| 211 |
+
):
|
| 212 |
+
if isinstance(vision_config, dict):
|
| 213 |
+
self.vision_config = Qwen2VLVisionConfig(**vision_config)
|
| 214 |
+
elif vision_config is None:
|
| 215 |
+
self.vision_config = Qwen2VLVisionConfig()
|
| 216 |
+
|
| 217 |
+
self.vocab_size = vocab_size
|
| 218 |
+
self.max_position_embeddings = max_position_embeddings
|
| 219 |
+
self.hidden_size = hidden_size
|
| 220 |
+
self.intermediate_size = intermediate_size
|
| 221 |
+
self.num_hidden_layers = num_hidden_layers
|
| 222 |
+
self.num_attention_heads = num_attention_heads
|
| 223 |
+
self.use_sliding_window = use_sliding_window
|
| 224 |
+
self.sliding_window = sliding_window
|
| 225 |
+
self.max_window_layers = max_window_layers
|
| 226 |
+
|
| 227 |
+
# for backward compatibility
|
| 228 |
+
if num_key_value_heads is None:
|
| 229 |
+
num_key_value_heads = num_attention_heads
|
| 230 |
+
|
| 231 |
+
self.num_key_value_heads = num_key_value_heads
|
| 232 |
+
self.hidden_act = hidden_act
|
| 233 |
+
self.initializer_range = initializer_range
|
| 234 |
+
self.rms_norm_eps = rms_norm_eps
|
| 235 |
+
self.use_cache = use_cache
|
| 236 |
+
self.rope_theta = rope_theta
|
| 237 |
+
self.attention_dropout = attention_dropout
|
| 238 |
+
self.rope_scaling = rope_scaling
|
| 239 |
+
|
| 240 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 241 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 242 |
+
# and change type from 'mrope' to 'default'
|
| 243 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 244 |
+
if self.rope_scaling["type"] == "mrope":
|
| 245 |
+
self.rope_scaling["type"] = "default"
|
| 246 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 247 |
+
rope_config_validation(self)
|
| 248 |
+
|
| 249 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
image_processing_dualvitok.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
|
| 3 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
| 4 |
+
|
| 5 |
+
from .image_processing_movqgan import MoVQImageProcessor
|
| 6 |
+
|
| 7 |
+
logger = logging.get_logger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DualViTokImageProcessor(MoVQImageProcessor):
|
| 11 |
+
r"""
|
| 12 |
+
Constructs a DualViTok image processor that dynamically resizes images based on the original images.
|
| 13 |
+
This image processor is based on MoVQImageProcessor with spatial_factor of 16.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 17 |
+
Whether to resize the image's (height, width) dimensions.
|
| 18 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 19 |
+
Resampling filter to use when resizing the image.
|
| 20 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 21 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
| 22 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 23 |
+
Scale factor to use if rescaling the image.
|
| 24 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 25 |
+
Whether to normalize the image.
|
| 26 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 27 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 28 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 29 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 30 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 31 |
+
Whether to convert the image to RGB.
|
| 32 |
+
min_pixels (`int`, *optional*, defaults to `512 * 512`):
|
| 33 |
+
The min pixels of the image to resize the image.
|
| 34 |
+
max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
|
| 35 |
+
The max pixels of the image to resize the image.
|
| 36 |
+
spatial_factor (`int`, *optional*, defautls to 8):
|
| 37 |
+
The spatial downsample factor the image will be downsampled in feature extracting phase
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
model_input_names = ["pixel_values"]
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
*args,
|
| 45 |
+
spatial_factor: int = 16,
|
| 46 |
+
**kwargs,
|
| 47 |
+
) -> None:
|
| 48 |
+
super().__init__(*args, spatial_factor=spatial_factor, **kwargs)
|
image_processing_movqgan.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image processor class for MoVQ."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from typing import Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 10 |
+
from transformers.image_transforms import (
|
| 11 |
+
convert_to_rgb,
|
| 12 |
+
resize,
|
| 13 |
+
to_channel_dimension_format,
|
| 14 |
+
)
|
| 15 |
+
from transformers.image_utils import (
|
| 16 |
+
IMAGENET_STANDARD_MEAN,
|
| 17 |
+
IMAGENET_STANDARD_STD,
|
| 18 |
+
ChannelDimension,
|
| 19 |
+
ImageInput,
|
| 20 |
+
PILImageResampling,
|
| 21 |
+
get_image_size,
|
| 22 |
+
infer_channel_dimension_format,
|
| 23 |
+
is_scaled_image,
|
| 24 |
+
make_list_of_images,
|
| 25 |
+
to_numpy_array,
|
| 26 |
+
valid_images,
|
| 27 |
+
validate_preprocess_arguments,
|
| 28 |
+
)
|
| 29 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if is_vision_available():
|
| 36 |
+
from PIL import Image
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def smart_resize(
|
| 40 |
+
height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024
|
| 41 |
+
):
|
| 42 |
+
"""Rescales the image so that the following conditions are met:
|
| 43 |
+
|
| 44 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 45 |
+
|
| 46 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 47 |
+
|
| 48 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
| 49 |
+
|
| 50 |
+
"""
|
| 51 |
+
# if height < factor or width < factor:
|
| 52 |
+
# raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
|
| 53 |
+
# elif max(height, width) / min(height, width) > 5:
|
| 54 |
+
# raise ValueError(
|
| 55 |
+
# f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}"
|
| 56 |
+
# )
|
| 57 |
+
|
| 58 |
+
h_bar = round(height / factor) * factor
|
| 59 |
+
w_bar = round(width / factor) * factor
|
| 60 |
+
if h_bar * w_bar > max_pixels:
|
| 61 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 62 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 63 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 64 |
+
elif h_bar * w_bar < min_pixels:
|
| 65 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 66 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 67 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 68 |
+
|
| 69 |
+
return max(h_bar, factor), max(w_bar, factor)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MoVQImageProcessor(BaseImageProcessor):
|
| 73 |
+
r"""
|
| 74 |
+
Constructs a MoVQ image processor that dynamically resizes images based on the original images.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 78 |
+
Whether to resize the image's (height, width) dimensions.
|
| 79 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 80 |
+
Resampling filter to use when resizing the image.
|
| 81 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 82 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
| 83 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 84 |
+
Scale factor to use if rescaling the image.
|
| 85 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 86 |
+
Whether to normalize the image.
|
| 87 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 88 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 89 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 90 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 91 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 92 |
+
Whether to convert the image to RGB.
|
| 93 |
+
min_pixels (`int`, *optional*, defaults to `512 * 512`):
|
| 94 |
+
The min pixels of the image to resize the image.
|
| 95 |
+
max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
|
| 96 |
+
The max pixels of the image to resize the image.
|
| 97 |
+
spatial_factor (`int`, *optional*, defautls to 8):
|
| 98 |
+
The spatial downsample factor the image will be downsampled in feature extracting phase
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
model_input_names = ["pixel_values"]
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
do_resize: bool = True,
|
| 106 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 107 |
+
do_rescale: bool = True,
|
| 108 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 109 |
+
do_normalize: bool = True,
|
| 110 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 111 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 112 |
+
do_convert_rgb: bool = True,
|
| 113 |
+
min_pixels: int = 32 * 32,
|
| 114 |
+
max_pixels: int = 1024 * 1024,
|
| 115 |
+
spatial_factor: int = 8,
|
| 116 |
+
**kwargs,
|
| 117 |
+
) -> None:
|
| 118 |
+
super().__init__(**kwargs)
|
| 119 |
+
self.do_resize = do_resize
|
| 120 |
+
self.resample = resample
|
| 121 |
+
self.do_rescale = do_rescale
|
| 122 |
+
self.rescale_factor = rescale_factor
|
| 123 |
+
self.do_normalize = do_normalize
|
| 124 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 125 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 126 |
+
self.min_pixels = min_pixels
|
| 127 |
+
self.max_pixels = max_pixels
|
| 128 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
| 129 |
+
self.do_convert_rgb = do_convert_rgb
|
| 130 |
+
self.spatial_factor = spatial_factor
|
| 131 |
+
|
| 132 |
+
def _preprocess(
|
| 133 |
+
self,
|
| 134 |
+
images: ImageInput,
|
| 135 |
+
do_resize: Optional[bool] = None,
|
| 136 |
+
resample: PILImageResampling = None,
|
| 137 |
+
do_rescale: Optional[bool] = None,
|
| 138 |
+
rescale_factor: Optional[float] = None,
|
| 139 |
+
do_normalize: Optional[bool] = None,
|
| 140 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 141 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 142 |
+
do_convert_rgb: Optional[bool] = None,
|
| 143 |
+
spatial_factor: Optional[int] = None,
|
| 144 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 145 |
+
output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
|
| 146 |
+
):
|
| 147 |
+
"""
|
| 148 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
images (`ImageInput`):
|
| 152 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
| 153 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 154 |
+
Whether to resize the image.
|
| 155 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 156 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
| 157 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 158 |
+
Whether to rescale the image.
|
| 159 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 160 |
+
Scale factor to use if rescaling the image.
|
| 161 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 162 |
+
Whether to normalize the image.
|
| 163 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 164 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 165 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 166 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 167 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 168 |
+
Whether to convert the image to RGB.
|
| 169 |
+
spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
|
| 170 |
+
The spatial downsample factor the image will be downsampled in feature extracting phase
|
| 171 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 172 |
+
The channel dimension format for the input image. Can be one of:
|
| 173 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 174 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 175 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 176 |
+
output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 177 |
+
The channel dimension format for the output image. Can be one of:
|
| 178 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 179 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 180 |
+
- Unset: Use the channel dimension format of the input image.
|
| 181 |
+
"""
|
| 182 |
+
spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
|
| 183 |
+
|
| 184 |
+
images = make_list_of_images(images)
|
| 185 |
+
if do_convert_rgb:
|
| 186 |
+
images = [convert_to_rgb(image) for image in images]
|
| 187 |
+
|
| 188 |
+
# All transformations expect numpy arrays.
|
| 189 |
+
images = [to_numpy_array(image) for image in images]
|
| 190 |
+
|
| 191 |
+
if is_scaled_image(images[0]) and do_rescale:
|
| 192 |
+
logger.warning_once(
|
| 193 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 194 |
+
"pixel_values.append()images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
if input_data_format is None:
|
| 198 |
+
# We assume that all images have the same channel dimension format.
|
| 199 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 200 |
+
|
| 201 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
| 202 |
+
resized_height, resized_width = height, width
|
| 203 |
+
processed_images = []
|
| 204 |
+
for image in images:
|
| 205 |
+
if do_resize:
|
| 206 |
+
resized_height, resized_width = smart_resize(
|
| 207 |
+
height,
|
| 208 |
+
width,
|
| 209 |
+
factor=spatial_factor,
|
| 210 |
+
min_pixels=self.min_pixels,
|
| 211 |
+
max_pixels=self.max_pixels,
|
| 212 |
+
)
|
| 213 |
+
image = resize(
|
| 214 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if do_rescale:
|
| 218 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
| 219 |
+
|
| 220 |
+
if do_normalize:
|
| 221 |
+
image = self.normalize(
|
| 222 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format)
|
| 226 |
+
processed_images.append(image)
|
| 227 |
+
|
| 228 |
+
image = np.array(processed_images)
|
| 229 |
+
return image
|
| 230 |
+
|
| 231 |
+
def preprocess(
|
| 232 |
+
self,
|
| 233 |
+
images: ImageInput,
|
| 234 |
+
do_resize: Optional[bool] = None,
|
| 235 |
+
resample: PILImageResampling = None,
|
| 236 |
+
do_rescale: Optional[bool] = None,
|
| 237 |
+
rescale_factor: Optional[float] = None,
|
| 238 |
+
do_normalize: Optional[bool] = None,
|
| 239 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 240 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 241 |
+
do_convert_rgb: Optional[bool] = None,
|
| 242 |
+
spatial_factor: Optional[int] = None,
|
| 243 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 244 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 245 |
+
output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
|
| 246 |
+
):
|
| 247 |
+
"""
|
| 248 |
+
Args:
|
| 249 |
+
images (`ImageInput`):
|
| 250 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 251 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 252 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 253 |
+
Whether to resize the image.
|
| 254 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 255 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 256 |
+
has an effect if `do_resize` is set to `True`.
|
| 257 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 258 |
+
Whether to rescale the image.
|
| 259 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 260 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 261 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 262 |
+
Whether to normalize the image.
|
| 263 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 264 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 265 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 266 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 267 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 268 |
+
Whether to convert the image to RGB.
|
| 269 |
+
spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
|
| 270 |
+
The spatial downsample factor the image will be downsampled in feature extracting phase
|
| 271 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 272 |
+
The type of tensors to return. Can be one of:
|
| 273 |
+
- Unset: Return a list of `np.ndarray`.
|
| 274 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 275 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 276 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 277 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 278 |
+
from the input image. Can be one of:
|
| 279 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 280 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 281 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 282 |
+
output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 283 |
+
The channel dimension format for the output image. Can be one of:
|
| 284 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 285 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 286 |
+
- Unset: Use the channel dimension format of the input image.
|
| 287 |
+
"""
|
| 288 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 289 |
+
resample = resample if resample is not None else self.resample
|
| 290 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 291 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 292 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 293 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 294 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 295 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 296 |
+
spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
|
| 297 |
+
|
| 298 |
+
images = make_list_of_images(images)
|
| 299 |
+
if images is None or not valid_images(images):
|
| 300 |
+
raise ValueError(
|
| 301 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 302 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
validate_preprocess_arguments(
|
| 306 |
+
rescale_factor=rescale_factor,
|
| 307 |
+
do_normalize=do_normalize,
|
| 308 |
+
image_mean=image_mean,
|
| 309 |
+
image_std=image_std,
|
| 310 |
+
do_resize=do_resize,
|
| 311 |
+
size=self.size,
|
| 312 |
+
resample=resample,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
pixel_values = []
|
| 316 |
+
for image in images:
|
| 317 |
+
norm_image = self._preprocess(
|
| 318 |
+
image,
|
| 319 |
+
do_resize=do_resize,
|
| 320 |
+
resample=resample,
|
| 321 |
+
do_rescale=do_rescale,
|
| 322 |
+
rescale_factor=rescale_factor,
|
| 323 |
+
do_normalize=do_normalize,
|
| 324 |
+
image_mean=image_mean,
|
| 325 |
+
image_std=image_std,
|
| 326 |
+
do_convert_rgb=do_convert_rgb,
|
| 327 |
+
spatial_factor=spatial_factor,
|
| 328 |
+
input_data_format=input_data_format,
|
| 329 |
+
output_data_format=output_data_format,
|
| 330 |
+
)
|
| 331 |
+
pixel_values.extend(norm_image)
|
| 332 |
+
pixel_values = np.array(pixel_values)
|
| 333 |
+
data = {"pixel_values": pixel_values}
|
| 334 |
+
|
| 335 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 336 |
+
|
| 337 |
+
def postprocess(
|
| 338 |
+
self,
|
| 339 |
+
images: ImageInput,
|
| 340 |
+
do_rescale: Optional[bool] = None,
|
| 341 |
+
rescale_factor: Optional[float] = None,
|
| 342 |
+
do_normalize: Optional[bool] = None,
|
| 343 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 344 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 345 |
+
return_tensors: Optional[Union[str, TensorType]] = "PIL.Image.Image",
|
| 346 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 347 |
+
):
|
| 348 |
+
"""
|
| 349 |
+
Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess.
|
| 350 |
+
The parameters should be same as in preprocess.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
images (`ImageInput`):
|
| 354 |
+
Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1.
|
| 355 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 356 |
+
Whether to rescale the image.
|
| 357 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 358 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 359 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 360 |
+
Whether to normalize the image.
|
| 361 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 362 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 363 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 364 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 365 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 366 |
+
The type of tensors to return. Can be one of:
|
| 367 |
+
- Unset: Return a list of `np.ndarray`.
|
| 368 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 369 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 370 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 371 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 372 |
+
from the input image. Can be one of:
|
| 373 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 374 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 375 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 376 |
+
"""
|
| 377 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 378 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 379 |
+
rescale_factor = 1 / rescale_factor
|
| 380 |
+
|
| 381 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 382 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 383 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 384 |
+
image_mean, image_std = self.inverse_meanstd(image_mean, image_std)
|
| 385 |
+
|
| 386 |
+
images = make_list_of_images(images)
|
| 387 |
+
if isinstance(images[0], Image.Image):
|
| 388 |
+
return images if len(images) > 1 else images[0]
|
| 389 |
+
|
| 390 |
+
if input_data_format is None:
|
| 391 |
+
# We assume that all images have the same channel dimension format.
|
| 392 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 393 |
+
|
| 394 |
+
pixel_values = []
|
| 395 |
+
for image in images:
|
| 396 |
+
image = to_numpy_array(image)
|
| 397 |
+
if do_normalize:
|
| 398 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 399 |
+
|
| 400 |
+
if do_rescale:
|
| 401 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
| 402 |
+
image = image.clip(0, 255).astype(np.uint8)
|
| 403 |
+
|
| 404 |
+
if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
|
| 405 |
+
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
|
| 406 |
+
pixel_values.append(Image.fromarray(image))
|
| 407 |
+
else:
|
| 408 |
+
pixel_values.extend(image)
|
| 409 |
+
|
| 410 |
+
data = {"pixel_values": pixel_values}
|
| 411 |
+
return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
|
| 412 |
+
|
| 413 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 414 |
+
|
| 415 |
+
def inverse_meanstd(self, image_mean, image_std):
|
| 416 |
+
image_mean = self.to_tuple(image_mean)
|
| 417 |
+
image_std = self.to_tuple(image_std)
|
| 418 |
+
|
| 419 |
+
rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std))
|
| 420 |
+
rev_image_std = tuple(1 / s for s in image_std)
|
| 421 |
+
|
| 422 |
+
return rev_image_mean, rev_image_std
|
| 423 |
+
|
| 424 |
+
def to_tuple(self, value, dim=3):
|
| 425 |
+
if isinstance(value, (int, float)):
|
| 426 |
+
return (value,) * dim
|
| 427 |
+
|
| 428 |
+
return tuple(value)
|
image_processing_qwen2vit.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""Image processor class for Qwen2-VL."""
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
from typing import Dict, List, Optional, Union
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
from torch import nn
|
| 28 |
+
|
| 29 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 30 |
+
from transformers.image_transforms import (
|
| 31 |
+
convert_to_rgb,
|
| 32 |
+
resize,
|
| 33 |
+
to_channel_dimension_format,
|
| 34 |
+
)
|
| 35 |
+
from .image_utils import (
|
| 36 |
+
OPENAI_CLIP_MEAN,
|
| 37 |
+
OPENAI_CLIP_STD,
|
| 38 |
+
ChannelDimension,
|
| 39 |
+
ImageInput,
|
| 40 |
+
PILImageResampling,
|
| 41 |
+
VideoInput,
|
| 42 |
+
get_image_size,
|
| 43 |
+
infer_channel_dimension_format,
|
| 44 |
+
is_scaled_image,
|
| 45 |
+
is_valid_image,
|
| 46 |
+
make_list_of_images,
|
| 47 |
+
to_numpy_array,
|
| 48 |
+
valid_images,
|
| 49 |
+
validate_preprocess_arguments,
|
| 50 |
+
)
|
| 51 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
logger = logging.get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if is_vision_available():
|
| 58 |
+
from PIL import Image
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def make_batched_images(images) -> List[List[ImageInput]]:
|
| 62 |
+
"""
|
| 63 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
| 67 |
+
The input image.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
list: A list of images.
|
| 71 |
+
"""
|
| 72 |
+
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
| 73 |
+
return [img for img_list in images for img in img_list]
|
| 74 |
+
|
| 75 |
+
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
| 76 |
+
return images
|
| 77 |
+
|
| 78 |
+
elif is_valid_image(images):
|
| 79 |
+
return [images]
|
| 80 |
+
|
| 81 |
+
raise ValueError(f"Could not make batched images from {images}")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
|
| 85 |
+
def make_batched_videos(videos) -> List[VideoInput]:
|
| 86 |
+
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
| 87 |
+
return videos
|
| 88 |
+
|
| 89 |
+
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
| 90 |
+
if isinstance(videos[0], Image.Image):
|
| 91 |
+
return [videos]
|
| 92 |
+
elif len(videos[0].shape) == 4:
|
| 93 |
+
return [list(video) for video in videos]
|
| 94 |
+
|
| 95 |
+
elif is_valid_image(videos) and len(videos.shape) == 4:
|
| 96 |
+
return [list(videos)]
|
| 97 |
+
|
| 98 |
+
raise ValueError(f"Could not make batched video from {videos}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def smart_resize(
|
| 102 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
|
| 103 |
+
):
|
| 104 |
+
"""Rescales the image so that the following conditions are met:
|
| 105 |
+
|
| 106 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 107 |
+
|
| 108 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 109 |
+
|
| 110 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
| 111 |
+
|
| 112 |
+
"""
|
| 113 |
+
if height < factor or width < factor:
|
| 114 |
+
# print("height, width", height, width)
|
| 115 |
+
if height < width:
|
| 116 |
+
h_bar = factor
|
| 117 |
+
w_bar = round(width / height * factor)
|
| 118 |
+
else:
|
| 119 |
+
h_bar = round(height / width * factor)
|
| 120 |
+
w_bar = factor
|
| 121 |
+
# print("h_bar, w_bar", h_bar, w_bar)
|
| 122 |
+
height, width = h_bar, w_bar
|
| 123 |
+
# raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
|
| 124 |
+
elif max(height, width) / min(height, width) > 200:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
| 127 |
+
)
|
| 128 |
+
h_bar = round(height / factor) * factor
|
| 129 |
+
w_bar = round(width / factor) * factor
|
| 130 |
+
if h_bar * w_bar > max_pixels:
|
| 131 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 132 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 133 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 134 |
+
elif h_bar * w_bar < min_pixels:
|
| 135 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 136 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 137 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 138 |
+
return h_bar, w_bar
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Qwen2VLImageProcessor(BaseImageProcessor):
|
| 142 |
+
r"""
|
| 143 |
+
Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 147 |
+
Whether to resize the image's (height, width) dimensions.
|
| 148 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 149 |
+
Resampling filter to use when resizing the image.
|
| 150 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 151 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
| 152 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 153 |
+
Scale factor to use if rescaling the image.
|
| 154 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 155 |
+
Whether to normalize the image.
|
| 156 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
| 157 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 158 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
| 159 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 160 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 161 |
+
Whether to convert the image to RGB.
|
| 162 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
| 163 |
+
The min pixels of the image to resize the image.
|
| 164 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
| 165 |
+
The max pixels of the image to resize the image.
|
| 166 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 167 |
+
The spacial patch size of the vision encoder.
|
| 168 |
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
| 169 |
+
The temporal patch size of the vision encoder.
|
| 170 |
+
merge_size (`int`, *optional*, defaults to 2):
|
| 171 |
+
The merge size of the vision encoder to llm encoder.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
|
| 175 |
+
|
| 176 |
+
def __init__(
|
| 177 |
+
self,
|
| 178 |
+
do_resize: bool = True,
|
| 179 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 180 |
+
do_rescale: bool = True,
|
| 181 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 182 |
+
do_normalize: bool = True,
|
| 183 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 184 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 185 |
+
do_convert_rgb: bool = True,
|
| 186 |
+
min_pixels: int = 56 * 56,
|
| 187 |
+
max_pixels: int = 28 * 28 * 1280,
|
| 188 |
+
patch_size: int = 14,
|
| 189 |
+
temporal_patch_size: int = 2,
|
| 190 |
+
merge_size: int = 2,
|
| 191 |
+
shifted_patch_tokenize=False,
|
| 192 |
+
**kwargs,
|
| 193 |
+
) -> None:
|
| 194 |
+
super().__init__(**kwargs)
|
| 195 |
+
self.do_resize = do_resize
|
| 196 |
+
self.resample = resample
|
| 197 |
+
self.do_rescale = do_rescale
|
| 198 |
+
self.rescale_factor = rescale_factor
|
| 199 |
+
self.do_normalize = do_normalize
|
| 200 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
| 201 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
| 202 |
+
self.min_pixels = min_pixels
|
| 203 |
+
self.max_pixels = max_pixels
|
| 204 |
+
self.patch_size = patch_size
|
| 205 |
+
self.temporal_patch_size = temporal_patch_size
|
| 206 |
+
self.merge_size = merge_size
|
| 207 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
| 208 |
+
self.do_convert_rgb = do_convert_rgb
|
| 209 |
+
self.shifted_patch_tokenize = shifted_patch_tokenize
|
| 210 |
+
|
| 211 |
+
def _preprocess(
|
| 212 |
+
self,
|
| 213 |
+
images: Union[ImageInput, VideoInput],
|
| 214 |
+
do_resize: bool = None,
|
| 215 |
+
resample: PILImageResampling = None,
|
| 216 |
+
do_rescale: bool = None,
|
| 217 |
+
rescale_factor: float = None,
|
| 218 |
+
do_normalize: bool = None,
|
| 219 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 220 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 221 |
+
do_convert_rgb: bool = None,
|
| 222 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 223 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 224 |
+
):
|
| 225 |
+
"""
|
| 226 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
images (`ImageInput`):
|
| 230 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
| 231 |
+
vision_info (`List[Dict]`, *optional*):
|
| 232 |
+
Optional list of dictionaries containing additional information about vision inputs.
|
| 233 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 234 |
+
Whether to resize the image.
|
| 235 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 236 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
| 237 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 238 |
+
Whether to rescale the image.
|
| 239 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 240 |
+
Scale factor to use if rescaling the image.
|
| 241 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 242 |
+
Whether to normalize the image.
|
| 243 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 244 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 245 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 246 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 247 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 248 |
+
Whether to convert the image to RGB.
|
| 249 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 250 |
+
The channel dimension format for the output image. Can be one of:
|
| 251 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 252 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 253 |
+
- Unset: Use the channel dimension format of the input image.
|
| 254 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 255 |
+
The channel dimension format for the input image. Can be one of:
|
| 256 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 257 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 258 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 259 |
+
"""
|
| 260 |
+
# import pdb; pdb.set_trace()
|
| 261 |
+
# print("images", images)
|
| 262 |
+
# for image in images:
|
| 263 |
+
# print("image", image.size)
|
| 264 |
+
images = make_list_of_images(images)
|
| 265 |
+
|
| 266 |
+
if do_convert_rgb:
|
| 267 |
+
images = [convert_to_rgb(image) for image in images]
|
| 268 |
+
|
| 269 |
+
# All transformations expect numpy arrays.
|
| 270 |
+
images = [to_numpy_array(image) for image in images]
|
| 271 |
+
|
| 272 |
+
if is_scaled_image(images[0]) and do_rescale:
|
| 273 |
+
logger.warning_once(
|
| 274 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 275 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 276 |
+
)
|
| 277 |
+
if input_data_format is None:
|
| 278 |
+
# We assume that all images have the same channel dimension format.
|
| 279 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 280 |
+
|
| 281 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
| 282 |
+
resized_height, resized_width = height, width
|
| 283 |
+
processed_images = []
|
| 284 |
+
for image in images:
|
| 285 |
+
if do_resize:
|
| 286 |
+
resized_height, resized_width = smart_resize(
|
| 287 |
+
height,
|
| 288 |
+
width,
|
| 289 |
+
factor=self.patch_size * self.merge_size,
|
| 290 |
+
min_pixels=self.min_pixels,
|
| 291 |
+
max_pixels=self.max_pixels,
|
| 292 |
+
)
|
| 293 |
+
image = resize(
|
| 294 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if do_rescale:
|
| 298 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
| 299 |
+
|
| 300 |
+
if do_normalize:
|
| 301 |
+
image = self.normalize(
|
| 302 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 306 |
+
processed_images.append(image)
|
| 307 |
+
|
| 308 |
+
patches = np.array(processed_images)
|
| 309 |
+
if data_format == ChannelDimension.LAST:
|
| 310 |
+
patches = patches.transpose(0, 3, 1, 2)
|
| 311 |
+
if patches.shape[0] == 1:
|
| 312 |
+
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
| 313 |
+
channel = patches.shape[1]
|
| 314 |
+
grid_t = patches.shape[0] // self.temporal_patch_size
|
| 315 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
| 316 |
+
patches = patches.reshape(
|
| 317 |
+
grid_t,
|
| 318 |
+
self.temporal_patch_size,
|
| 319 |
+
channel,
|
| 320 |
+
grid_h // self.merge_size,
|
| 321 |
+
self.merge_size,
|
| 322 |
+
self.patch_size,
|
| 323 |
+
grid_w // self.merge_size,
|
| 324 |
+
self.merge_size,
|
| 325 |
+
self.patch_size,
|
| 326 |
+
)
|
| 327 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
| 328 |
+
flatten_patches = patches.reshape(
|
| 329 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
return flatten_patches, (grid_t, grid_h, grid_w)
|
| 333 |
+
|
| 334 |
+
def preprocess(
|
| 335 |
+
self,
|
| 336 |
+
images: ImageInput,
|
| 337 |
+
videos: VideoInput = None,
|
| 338 |
+
do_resize: bool = None,
|
| 339 |
+
size: Dict[str, int] = None,
|
| 340 |
+
resample: PILImageResampling = None,
|
| 341 |
+
do_rescale: bool = None,
|
| 342 |
+
rescale_factor: float = None,
|
| 343 |
+
do_normalize: bool = None,
|
| 344 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 345 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 346 |
+
do_convert_rgb: bool = None,
|
| 347 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 348 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 349 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 350 |
+
):
|
| 351 |
+
"""
|
| 352 |
+
Args:
|
| 353 |
+
images (`ImageInput`):
|
| 354 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 355 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 356 |
+
videos (`VideoInput`):
|
| 357 |
+
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
|
| 358 |
+
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
|
| 359 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 360 |
+
Whether to resize the image.
|
| 361 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 362 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
| 363 |
+
the longest edge resized to keep the input aspect ratio.
|
| 364 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 365 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 366 |
+
has an effect if `do_resize` is set to `True`.
|
| 367 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 368 |
+
Whether to rescale the image.
|
| 369 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 370 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 371 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 372 |
+
Whether to normalize the image.
|
| 373 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 374 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 375 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 376 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 377 |
+
`True`.
|
| 378 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 379 |
+
Whether to convert the image to RGB.
|
| 380 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 381 |
+
The type of tensors to return. Can be one of:
|
| 382 |
+
- Unset: Return a list of `np.ndarray`.
|
| 383 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 384 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 385 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 386 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 387 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 388 |
+
The channel dimension format for the output image. Can be one of:
|
| 389 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 390 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 391 |
+
- Unset: Use the channel dimension format of the input image.
|
| 392 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 393 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 394 |
+
from the input image. Can be one of:
|
| 395 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 396 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 397 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 398 |
+
|
| 399 |
+
"""
|
| 400 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 401 |
+
size = size if size is not None else self.size
|
| 402 |
+
resample = resample if resample is not None else self.resample
|
| 403 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 404 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 405 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 406 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 407 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 408 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 409 |
+
|
| 410 |
+
if images is not None:
|
| 411 |
+
images = make_batched_images(images)
|
| 412 |
+
if videos is not None:
|
| 413 |
+
videos = make_batched_videos(videos)
|
| 414 |
+
|
| 415 |
+
if images is not None and not valid_images(images):
|
| 416 |
+
raise ValueError(
|
| 417 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 418 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
validate_preprocess_arguments(
|
| 422 |
+
rescale_factor=rescale_factor,
|
| 423 |
+
do_normalize=do_normalize,
|
| 424 |
+
image_mean=image_mean,
|
| 425 |
+
image_std=image_std,
|
| 426 |
+
do_resize=do_resize,
|
| 427 |
+
size=size,
|
| 428 |
+
resample=resample,
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
if images is not None:
|
| 432 |
+
pixel_values, vision_grid_thws = [], []
|
| 433 |
+
for image in images:
|
| 434 |
+
patches, image_grid_thw = self._preprocess(
|
| 435 |
+
image,
|
| 436 |
+
do_resize=do_resize,
|
| 437 |
+
resample=resample,
|
| 438 |
+
do_rescale=do_rescale,
|
| 439 |
+
rescale_factor=rescale_factor,
|
| 440 |
+
do_normalize=do_normalize,
|
| 441 |
+
image_mean=image_mean,
|
| 442 |
+
image_std=image_std,
|
| 443 |
+
data_format=data_format,
|
| 444 |
+
do_convert_rgb=do_convert_rgb,
|
| 445 |
+
input_data_format=input_data_format,
|
| 446 |
+
)
|
| 447 |
+
pixel_values.extend(patches)
|
| 448 |
+
vision_grid_thws.append(image_grid_thw)
|
| 449 |
+
pixel_values = np.array(pixel_values)
|
| 450 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
| 451 |
+
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
| 452 |
+
|
| 453 |
+
if videos is not None:
|
| 454 |
+
pixel_values, vision_grid_thws = [], []
|
| 455 |
+
for images in videos:
|
| 456 |
+
patches, video_grid_thw = self._preprocess(
|
| 457 |
+
images,
|
| 458 |
+
do_resize=do_resize,
|
| 459 |
+
resample=resample,
|
| 460 |
+
do_rescale=do_rescale,
|
| 461 |
+
rescale_factor=rescale_factor,
|
| 462 |
+
do_normalize=do_normalize,
|
| 463 |
+
image_mean=image_mean,
|
| 464 |
+
image_std=image_std,
|
| 465 |
+
data_format=data_format,
|
| 466 |
+
do_convert_rgb=do_convert_rgb,
|
| 467 |
+
input_data_format=input_data_format,
|
| 468 |
+
)
|
| 469 |
+
pixel_values.extend(patches)
|
| 470 |
+
vision_grid_thws.append(video_grid_thw)
|
| 471 |
+
pixel_values = np.array(pixel_values)
|
| 472 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
| 473 |
+
data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
| 474 |
+
|
| 475 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 476 |
+
|
image_utils.py
ADDED
|
@@ -0,0 +1,812 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import base64
|
| 17 |
+
import os
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import requests
|
| 23 |
+
from packaging import version
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from transformers.utils import (
|
| 27 |
+
ExplicitEnum,
|
| 28 |
+
is_jax_tensor,
|
| 29 |
+
is_numpy_array,
|
| 30 |
+
is_tf_tensor,
|
| 31 |
+
is_torch_available,
|
| 32 |
+
is_torch_tensor,
|
| 33 |
+
is_torchvision_available,
|
| 34 |
+
is_vision_available,
|
| 35 |
+
logging,
|
| 36 |
+
requires_backends,
|
| 37 |
+
to_numpy,
|
| 38 |
+
)
|
| 39 |
+
from transformers.utils.constants import ( # noqa: F401
|
| 40 |
+
IMAGENET_DEFAULT_MEAN,
|
| 41 |
+
IMAGENET_DEFAULT_STD,
|
| 42 |
+
IMAGENET_STANDARD_MEAN,
|
| 43 |
+
IMAGENET_STANDARD_STD,
|
| 44 |
+
OPENAI_CLIP_MEAN,
|
| 45 |
+
OPENAI_CLIP_STD,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_vision_available():
|
| 50 |
+
import PIL.Image
|
| 51 |
+
import PIL.ImageOps
|
| 52 |
+
|
| 53 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
| 54 |
+
PILImageResampling = PIL.Image.Resampling
|
| 55 |
+
else:
|
| 56 |
+
PILImageResampling = PIL.Image
|
| 57 |
+
|
| 58 |
+
if is_torchvision_available():
|
| 59 |
+
from torchvision.transforms import InterpolationMode
|
| 60 |
+
|
| 61 |
+
pil_torch_interpolation_mapping = {
|
| 62 |
+
PILImageResampling.NEAREST: InterpolationMode.NEAREST,
|
| 63 |
+
PILImageResampling.BOX: InterpolationMode.BOX,
|
| 64 |
+
PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
|
| 65 |
+
PILImageResampling.HAMMING: InterpolationMode.HAMMING,
|
| 66 |
+
PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
|
| 67 |
+
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if TYPE_CHECKING:
|
| 72 |
+
if is_torch_available():
|
| 73 |
+
import torch
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
logger = logging.get_logger(__name__)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
ImageInput = Union[
|
| 80 |
+
"PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
|
| 81 |
+
] # noqa
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
VideoInput = Union[
|
| 85 |
+
List["PIL.Image.Image"],
|
| 86 |
+
"np.ndarray",
|
| 87 |
+
"torch.Tensor",
|
| 88 |
+
List["np.ndarray"],
|
| 89 |
+
List["torch.Tensor"],
|
| 90 |
+
List[List["PIL.Image.Image"]],
|
| 91 |
+
List[List["np.ndarrray"]],
|
| 92 |
+
List[List["torch.Tensor"]],
|
| 93 |
+
] # noqa
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ChannelDimension(ExplicitEnum):
|
| 97 |
+
FIRST = "channels_first"
|
| 98 |
+
LAST = "channels_last"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class AnnotationFormat(ExplicitEnum):
|
| 102 |
+
COCO_DETECTION = "coco_detection"
|
| 103 |
+
COCO_PANOPTIC = "coco_panoptic"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class AnnotionFormat(ExplicitEnum):
|
| 107 |
+
COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
|
| 108 |
+
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def is_pil_image(img):
|
| 115 |
+
return is_vision_available() and isinstance(img, PIL.Image.Image)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ImageType(ExplicitEnum):
|
| 119 |
+
PIL = "pillow"
|
| 120 |
+
TORCH = "torch"
|
| 121 |
+
NUMPY = "numpy"
|
| 122 |
+
TENSORFLOW = "tensorflow"
|
| 123 |
+
JAX = "jax"
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_image_type(image):
|
| 127 |
+
if is_pil_image(image):
|
| 128 |
+
return ImageType.PIL
|
| 129 |
+
if is_torch_tensor(image):
|
| 130 |
+
return ImageType.TORCH
|
| 131 |
+
if is_numpy_array(image):
|
| 132 |
+
return ImageType.NUMPY
|
| 133 |
+
if is_tf_tensor(image):
|
| 134 |
+
return ImageType.TENSORFLOW
|
| 135 |
+
if is_jax_tensor(image):
|
| 136 |
+
return ImageType.JAX
|
| 137 |
+
raise ValueError(f"Unrecognised image type {type(image)}")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def is_valid_image(img):
|
| 141 |
+
return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def valid_images(imgs):
|
| 145 |
+
# If we have an list of images, make sure every image is valid
|
| 146 |
+
if isinstance(imgs, (list, tuple)):
|
| 147 |
+
for img in imgs:
|
| 148 |
+
if not valid_images(img):
|
| 149 |
+
return False
|
| 150 |
+
# If not a list of tuple, we have been given a single image or batched tensor of images
|
| 151 |
+
elif not is_valid_image(imgs):
|
| 152 |
+
return False
|
| 153 |
+
return True
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def is_batched(img):
|
| 157 |
+
if isinstance(img, (list, tuple)):
|
| 158 |
+
return is_valid_image(img[0])
|
| 159 |
+
return False
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def is_scaled_image(image: np.ndarray) -> bool:
|
| 163 |
+
"""
|
| 164 |
+
Checks to see whether the pixel values have already been rescaled to [0, 1].
|
| 165 |
+
"""
|
| 166 |
+
if image.dtype == np.uint8:
|
| 167 |
+
return False
|
| 168 |
+
|
| 169 |
+
# It's possible the image has pixel values in [0, 255] but is of floating type
|
| 170 |
+
return np.min(image) >= 0 and np.max(image) <= 1
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
|
| 174 |
+
"""
|
| 175 |
+
Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
|
| 176 |
+
If the input is a batch of images, it is converted to a list of images.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
images (`ImageInput`):
|
| 180 |
+
Image of images to turn into a list of images.
|
| 181 |
+
expected_ndims (`int`, *optional*, defaults to 3):
|
| 182 |
+
Expected number of dimensions for a single input image. If the input image has a different number of
|
| 183 |
+
dimensions, an error is raised.
|
| 184 |
+
"""
|
| 185 |
+
if is_batched(images):
|
| 186 |
+
return images
|
| 187 |
+
|
| 188 |
+
# Either the input is a single image, in which case we create a list of length 1
|
| 189 |
+
if isinstance(images, PIL.Image.Image):
|
| 190 |
+
# PIL images are never batched
|
| 191 |
+
return [images]
|
| 192 |
+
|
| 193 |
+
if is_valid_image(images):
|
| 194 |
+
if images.ndim == expected_ndims + 1:
|
| 195 |
+
# Batch of images
|
| 196 |
+
images = list(images)
|
| 197 |
+
elif images.ndim == expected_ndims:
|
| 198 |
+
# Single image
|
| 199 |
+
images = [images]
|
| 200 |
+
else:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
|
| 203 |
+
f" {images.ndim} dimensions."
|
| 204 |
+
)
|
| 205 |
+
return images
|
| 206 |
+
raise ValueError(
|
| 207 |
+
"Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
|
| 208 |
+
f"jax.ndarray, but got {type(images)}."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def to_numpy_array(img) -> np.ndarray:
|
| 213 |
+
if not is_valid_image(img):
|
| 214 |
+
raise ValueError(f"Invalid image type: {type(img)}")
|
| 215 |
+
|
| 216 |
+
if is_vision_available() and isinstance(img, PIL.Image.Image):
|
| 217 |
+
return np.array(img)
|
| 218 |
+
return to_numpy(img)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def infer_channel_dimension_format(
|
| 222 |
+
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
|
| 223 |
+
) -> ChannelDimension:
|
| 224 |
+
"""
|
| 225 |
+
Infers the channel dimension format of `image`.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
image (`np.ndarray`):
|
| 229 |
+
The image to infer the channel dimension of.
|
| 230 |
+
num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
|
| 231 |
+
The number of channels of the image.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
The channel dimension of the image.
|
| 235 |
+
"""
|
| 236 |
+
num_channels = num_channels if num_channels is not None else (1, 3)
|
| 237 |
+
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
|
| 238 |
+
|
| 239 |
+
if image.ndim == 3:
|
| 240 |
+
first_dim, last_dim = 0, 2
|
| 241 |
+
elif image.ndim == 4:
|
| 242 |
+
first_dim, last_dim = 1, 3
|
| 243 |
+
else:
|
| 244 |
+
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
| 245 |
+
|
| 246 |
+
if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
|
| 247 |
+
logger.warning(
|
| 248 |
+
f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
|
| 249 |
+
)
|
| 250 |
+
return ChannelDimension.FIRST
|
| 251 |
+
elif image.shape[first_dim] in num_channels:
|
| 252 |
+
return ChannelDimension.FIRST
|
| 253 |
+
elif image.shape[last_dim] in num_channels:
|
| 254 |
+
return ChannelDimension.LAST
|
| 255 |
+
raise ValueError("Unable to infer channel dimension format")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def get_channel_dimension_axis(
|
| 259 |
+
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
|
| 260 |
+
) -> int:
|
| 261 |
+
"""
|
| 262 |
+
Returns the channel dimension axis of the image.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
image (`np.ndarray`):
|
| 266 |
+
The image to get the channel dimension axis of.
|
| 267 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 268 |
+
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
The channel dimension axis of the image.
|
| 272 |
+
"""
|
| 273 |
+
if input_data_format is None:
|
| 274 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 275 |
+
if input_data_format == ChannelDimension.FIRST:
|
| 276 |
+
return image.ndim - 3
|
| 277 |
+
elif input_data_format == ChannelDimension.LAST:
|
| 278 |
+
return image.ndim - 1
|
| 279 |
+
raise ValueError(f"Unsupported data format: {input_data_format}")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
|
| 283 |
+
"""
|
| 284 |
+
Returns the (height, width) dimensions of the image.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
image (`np.ndarray`):
|
| 288 |
+
The image to get the dimensions of.
|
| 289 |
+
channel_dim (`ChannelDimension`, *optional*):
|
| 290 |
+
Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
|
| 291 |
+
|
| 292 |
+
Returns:
|
| 293 |
+
A tuple of the image's height and width.
|
| 294 |
+
"""
|
| 295 |
+
if channel_dim is None:
|
| 296 |
+
channel_dim = infer_channel_dimension_format(image)
|
| 297 |
+
|
| 298 |
+
if channel_dim == ChannelDimension.FIRST:
|
| 299 |
+
return image.shape[-2], image.shape[-1]
|
| 300 |
+
elif channel_dim == ChannelDimension.LAST:
|
| 301 |
+
return image.shape[-3], image.shape[-2]
|
| 302 |
+
else:
|
| 303 |
+
raise ValueError(f"Unsupported data format: {channel_dim}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
|
| 307 |
+
if (
|
| 308 |
+
isinstance(annotation, dict)
|
| 309 |
+
and "image_id" in annotation
|
| 310 |
+
and "annotations" in annotation
|
| 311 |
+
and isinstance(annotation["annotations"], (list, tuple))
|
| 312 |
+
and (
|
| 313 |
+
# an image can have no annotations
|
| 314 |
+
len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
|
| 315 |
+
)
|
| 316 |
+
):
|
| 317 |
+
return True
|
| 318 |
+
return False
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
|
| 322 |
+
if (
|
| 323 |
+
isinstance(annotation, dict)
|
| 324 |
+
and "image_id" in annotation
|
| 325 |
+
and "segments_info" in annotation
|
| 326 |
+
and "file_name" in annotation
|
| 327 |
+
and isinstance(annotation["segments_info"], (list, tuple))
|
| 328 |
+
and (
|
| 329 |
+
# an image can have no segments
|
| 330 |
+
len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
|
| 331 |
+
)
|
| 332 |
+
):
|
| 333 |
+
return True
|
| 334 |
+
return False
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
|
| 338 |
+
return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
|
| 342 |
+
return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
|
| 346 |
+
"""
|
| 347 |
+
Loads `image` to a PIL Image.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
image (`str` or `PIL.Image.Image`):
|
| 351 |
+
The image to convert to the PIL Image format.
|
| 352 |
+
timeout (`float`, *optional*):
|
| 353 |
+
The timeout value in seconds for the URL request.
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
`PIL.Image.Image`: A PIL Image.
|
| 357 |
+
"""
|
| 358 |
+
requires_backends(load_image, ["vision"])
|
| 359 |
+
if isinstance(image, str):
|
| 360 |
+
if image.startswith("http://") or image.startswith("https://"):
|
| 361 |
+
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
| 362 |
+
# like http_huggingface_co.png
|
| 363 |
+
image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
|
| 364 |
+
elif os.path.isfile(image):
|
| 365 |
+
image = PIL.Image.open(image)
|
| 366 |
+
else:
|
| 367 |
+
if image.startswith("data:image/"):
|
| 368 |
+
image = image.split(",")[1]
|
| 369 |
+
|
| 370 |
+
# Try to load as base64
|
| 371 |
+
try:
|
| 372 |
+
b64 = base64.decodebytes(image.encode())
|
| 373 |
+
image = PIL.Image.open(BytesIO(b64))
|
| 374 |
+
except Exception as e:
|
| 375 |
+
raise ValueError(
|
| 376 |
+
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
| 377 |
+
)
|
| 378 |
+
elif isinstance(image, PIL.Image.Image):
|
| 379 |
+
image = image
|
| 380 |
+
else:
|
| 381 |
+
raise TypeError(
|
| 382 |
+
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
|
| 383 |
+
)
|
| 384 |
+
image = PIL.ImageOps.exif_transpose(image)
|
| 385 |
+
image = image.convert("RGB")
|
| 386 |
+
return image
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def validate_preprocess_arguments(
|
| 390 |
+
do_rescale: Optional[bool] = None,
|
| 391 |
+
rescale_factor: Optional[float] = None,
|
| 392 |
+
do_normalize: Optional[bool] = None,
|
| 393 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 394 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 395 |
+
do_pad: Optional[bool] = None,
|
| 396 |
+
size_divisibility: Optional[int] = None,
|
| 397 |
+
do_center_crop: Optional[bool] = None,
|
| 398 |
+
crop_size: Optional[Dict[str, int]] = None,
|
| 399 |
+
do_resize: Optional[bool] = None,
|
| 400 |
+
size: Optional[Dict[str, int]] = None,
|
| 401 |
+
resample: Optional["PILImageResampling"] = None,
|
| 402 |
+
):
|
| 403 |
+
"""
|
| 404 |
+
Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
|
| 405 |
+
Raises `ValueError` if arguments incompatibility is caught.
|
| 406 |
+
Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
|
| 407 |
+
sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
|
| 408 |
+
existing arguments when possible.
|
| 409 |
+
|
| 410 |
+
"""
|
| 411 |
+
if do_rescale and rescale_factor is None:
|
| 412 |
+
raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
|
| 413 |
+
|
| 414 |
+
if do_pad and size_divisibility is None:
|
| 415 |
+
# Here, size_divisor might be passed as the value of size
|
| 416 |
+
raise ValueError(
|
| 417 |
+
"Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
if do_normalize and (image_mean is None or image_std is None):
|
| 421 |
+
raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
|
| 422 |
+
|
| 423 |
+
if do_center_crop and crop_size is None:
|
| 424 |
+
raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
|
| 425 |
+
|
| 426 |
+
if do_resize and (size is None or resample is None):
|
| 427 |
+
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# In the future we can add a TF implementation here when we have TF models.
|
| 431 |
+
class ImageFeatureExtractionMixin:
|
| 432 |
+
"""
|
| 433 |
+
Mixin that contain utilities for preparing image features.
|
| 434 |
+
"""
|
| 435 |
+
|
| 436 |
+
def _ensure_format_supported(self, image):
|
| 437 |
+
if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
|
| 438 |
+
raise ValueError(
|
| 439 |
+
f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
|
| 440 |
+
"`torch.Tensor` are."
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
def to_pil_image(self, image, rescale=None):
|
| 444 |
+
"""
|
| 445 |
+
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
| 446 |
+
needed.
|
| 447 |
+
|
| 448 |
+
Args:
|
| 449 |
+
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
|
| 450 |
+
The image to convert to the PIL Image format.
|
| 451 |
+
rescale (`bool`, *optional*):
|
| 452 |
+
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
|
| 453 |
+
default to `True` if the image type is a floating type, `False` otherwise.
|
| 454 |
+
"""
|
| 455 |
+
self._ensure_format_supported(image)
|
| 456 |
+
|
| 457 |
+
if is_torch_tensor(image):
|
| 458 |
+
image = image.numpy()
|
| 459 |
+
|
| 460 |
+
if isinstance(image, np.ndarray):
|
| 461 |
+
if rescale is None:
|
| 462 |
+
# rescale default to the array being of floating type.
|
| 463 |
+
rescale = isinstance(image.flat[0], np.floating)
|
| 464 |
+
# If the channel as been moved to first dim, we put it back at the end.
|
| 465 |
+
if image.ndim == 3 and image.shape[0] in [1, 3]:
|
| 466 |
+
image = image.transpose(1, 2, 0)
|
| 467 |
+
if rescale:
|
| 468 |
+
image = image * 255
|
| 469 |
+
image = image.astype(np.uint8)
|
| 470 |
+
return PIL.Image.fromarray(image)
|
| 471 |
+
return image
|
| 472 |
+
|
| 473 |
+
def convert_rgb(self, image):
|
| 474 |
+
"""
|
| 475 |
+
Converts `PIL.Image.Image` to RGB format.
|
| 476 |
+
|
| 477 |
+
Args:
|
| 478 |
+
image (`PIL.Image.Image`):
|
| 479 |
+
The image to convert.
|
| 480 |
+
"""
|
| 481 |
+
self._ensure_format_supported(image)
|
| 482 |
+
if not isinstance(image, PIL.Image.Image):
|
| 483 |
+
return image
|
| 484 |
+
|
| 485 |
+
return image.convert("RGB")
|
| 486 |
+
|
| 487 |
+
def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
|
| 488 |
+
"""
|
| 489 |
+
Rescale a numpy image by scale amount
|
| 490 |
+
"""
|
| 491 |
+
self._ensure_format_supported(image)
|
| 492 |
+
return image * scale
|
| 493 |
+
|
| 494 |
+
def to_numpy_array(self, image, rescale=None, channel_first=True):
|
| 495 |
+
"""
|
| 496 |
+
Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
|
| 497 |
+
dimension.
|
| 498 |
+
|
| 499 |
+
Args:
|
| 500 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 501 |
+
The image to convert to a NumPy array.
|
| 502 |
+
rescale (`bool`, *optional*):
|
| 503 |
+
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
|
| 504 |
+
default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
|
| 505 |
+
channel_first (`bool`, *optional*, defaults to `True`):
|
| 506 |
+
Whether or not to permute the dimensions of the image to put the channel dimension first.
|
| 507 |
+
"""
|
| 508 |
+
self._ensure_format_supported(image)
|
| 509 |
+
|
| 510 |
+
if isinstance(image, PIL.Image.Image):
|
| 511 |
+
image = np.array(image)
|
| 512 |
+
|
| 513 |
+
if is_torch_tensor(image):
|
| 514 |
+
image = image.numpy()
|
| 515 |
+
|
| 516 |
+
rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
|
| 517 |
+
|
| 518 |
+
if rescale:
|
| 519 |
+
image = self.rescale(image.astype(np.float32), 1 / 255.0)
|
| 520 |
+
|
| 521 |
+
if channel_first and image.ndim == 3:
|
| 522 |
+
image = image.transpose(2, 0, 1)
|
| 523 |
+
|
| 524 |
+
return image
|
| 525 |
+
|
| 526 |
+
def expand_dims(self, image):
|
| 527 |
+
"""
|
| 528 |
+
Expands 2-dimensional `image` to 3 dimensions.
|
| 529 |
+
|
| 530 |
+
Args:
|
| 531 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 532 |
+
The image to expand.
|
| 533 |
+
"""
|
| 534 |
+
self._ensure_format_supported(image)
|
| 535 |
+
|
| 536 |
+
# Do nothing if PIL image
|
| 537 |
+
if isinstance(image, PIL.Image.Image):
|
| 538 |
+
return image
|
| 539 |
+
|
| 540 |
+
if is_torch_tensor(image):
|
| 541 |
+
image = image.unsqueeze(0)
|
| 542 |
+
else:
|
| 543 |
+
image = np.expand_dims(image, axis=0)
|
| 544 |
+
return image
|
| 545 |
+
|
| 546 |
+
def normalize(self, image, mean, std, rescale=False):
|
| 547 |
+
"""
|
| 548 |
+
Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
|
| 549 |
+
if it's a PIL Image.
|
| 550 |
+
|
| 551 |
+
Args:
|
| 552 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 553 |
+
The image to normalize.
|
| 554 |
+
mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
|
| 555 |
+
The mean (per channel) to use for normalization.
|
| 556 |
+
std (`List[float]` or `np.ndarray` or `torch.Tensor`):
|
| 557 |
+
The standard deviation (per channel) to use for normalization.
|
| 558 |
+
rescale (`bool`, *optional*, defaults to `False`):
|
| 559 |
+
Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
|
| 560 |
+
happen automatically.
|
| 561 |
+
"""
|
| 562 |
+
self._ensure_format_supported(image)
|
| 563 |
+
|
| 564 |
+
if isinstance(image, PIL.Image.Image):
|
| 565 |
+
image = self.to_numpy_array(image, rescale=True)
|
| 566 |
+
# If the input image is a PIL image, it automatically gets rescaled. If it's another
|
| 567 |
+
# type it may need rescaling.
|
| 568 |
+
elif rescale:
|
| 569 |
+
if isinstance(image, np.ndarray):
|
| 570 |
+
image = self.rescale(image.astype(np.float32), 1 / 255.0)
|
| 571 |
+
elif is_torch_tensor(image):
|
| 572 |
+
image = self.rescale(image.float(), 1 / 255.0)
|
| 573 |
+
|
| 574 |
+
if isinstance(image, np.ndarray):
|
| 575 |
+
if not isinstance(mean, np.ndarray):
|
| 576 |
+
mean = np.array(mean).astype(image.dtype)
|
| 577 |
+
if not isinstance(std, np.ndarray):
|
| 578 |
+
std = np.array(std).astype(image.dtype)
|
| 579 |
+
elif is_torch_tensor(image):
|
| 580 |
+
import torch
|
| 581 |
+
|
| 582 |
+
if not isinstance(mean, torch.Tensor):
|
| 583 |
+
if isinstance(mean, np.ndarray):
|
| 584 |
+
mean = torch.from_numpy(mean)
|
| 585 |
+
else:
|
| 586 |
+
mean = torch.tensor(mean)
|
| 587 |
+
if not isinstance(std, torch.Tensor):
|
| 588 |
+
if isinstance(std, np.ndarray):
|
| 589 |
+
std = torch.from_numpy(std)
|
| 590 |
+
else:
|
| 591 |
+
std = torch.tensor(std)
|
| 592 |
+
|
| 593 |
+
if image.ndim == 3 and image.shape[0] in [1, 3]:
|
| 594 |
+
return (image - mean[:, None, None]) / std[:, None, None]
|
| 595 |
+
else:
|
| 596 |
+
return (image - mean) / std
|
| 597 |
+
|
| 598 |
+
def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
|
| 599 |
+
"""
|
| 600 |
+
Resizes `image`. Enforces conversion of input to PIL.Image.
|
| 601 |
+
|
| 602 |
+
Args:
|
| 603 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 604 |
+
The image to resize.
|
| 605 |
+
size (`int` or `Tuple[int, int]`):
|
| 606 |
+
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
|
| 607 |
+
matched to this.
|
| 608 |
+
|
| 609 |
+
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
|
| 610 |
+
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
|
| 611 |
+
this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
|
| 612 |
+
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 613 |
+
The filter to user for resampling.
|
| 614 |
+
default_to_square (`bool`, *optional*, defaults to `True`):
|
| 615 |
+
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
|
| 616 |
+
square (`size`,`size`). If set to `False`, will replicate
|
| 617 |
+
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
|
| 618 |
+
with support for resizing only the smallest edge and providing an optional `max_size`.
|
| 619 |
+
max_size (`int`, *optional*, defaults to `None`):
|
| 620 |
+
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
|
| 621 |
+
greater than `max_size` after being resized according to `size`, then the image is resized again so
|
| 622 |
+
that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
|
| 623 |
+
edge may be shorter than `size`. Only used if `default_to_square` is `False`.
|
| 624 |
+
|
| 625 |
+
Returns:
|
| 626 |
+
image: A resized `PIL.Image.Image`.
|
| 627 |
+
"""
|
| 628 |
+
resample = resample if resample is not None else PILImageResampling.BILINEAR
|
| 629 |
+
|
| 630 |
+
self._ensure_format_supported(image)
|
| 631 |
+
|
| 632 |
+
if not isinstance(image, PIL.Image.Image):
|
| 633 |
+
image = self.to_pil_image(image)
|
| 634 |
+
|
| 635 |
+
if isinstance(size, list):
|
| 636 |
+
size = tuple(size)
|
| 637 |
+
|
| 638 |
+
if isinstance(size, int) or len(size) == 1:
|
| 639 |
+
if default_to_square:
|
| 640 |
+
size = (size, size) if isinstance(size, int) else (size[0], size[0])
|
| 641 |
+
else:
|
| 642 |
+
width, height = image.size
|
| 643 |
+
# specified size only for the smallest edge
|
| 644 |
+
short, long = (width, height) if width <= height else (height, width)
|
| 645 |
+
requested_new_short = size if isinstance(size, int) else size[0]
|
| 646 |
+
|
| 647 |
+
if short == requested_new_short:
|
| 648 |
+
return image
|
| 649 |
+
|
| 650 |
+
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
| 651 |
+
|
| 652 |
+
if max_size is not None:
|
| 653 |
+
if max_size <= requested_new_short:
|
| 654 |
+
raise ValueError(
|
| 655 |
+
f"max_size = {max_size} must be strictly greater than the requested "
|
| 656 |
+
f"size for the smaller edge size = {size}"
|
| 657 |
+
)
|
| 658 |
+
if new_long > max_size:
|
| 659 |
+
new_short, new_long = int(max_size * new_short / new_long), max_size
|
| 660 |
+
|
| 661 |
+
size = (new_short, new_long) if width <= height else (new_long, new_short)
|
| 662 |
+
|
| 663 |
+
return image.resize(size, resample=resample)
|
| 664 |
+
|
| 665 |
+
def center_crop(self, image, size):
|
| 666 |
+
"""
|
| 667 |
+
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
|
| 668 |
+
size given, it will be padded (so the returned result has the size asked).
|
| 669 |
+
|
| 670 |
+
Args:
|
| 671 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
|
| 672 |
+
The image to resize.
|
| 673 |
+
size (`int` or `Tuple[int, int]`):
|
| 674 |
+
The size to which crop the image.
|
| 675 |
+
|
| 676 |
+
Returns:
|
| 677 |
+
new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
|
| 678 |
+
height, width).
|
| 679 |
+
"""
|
| 680 |
+
self._ensure_format_supported(image)
|
| 681 |
+
|
| 682 |
+
if not isinstance(size, tuple):
|
| 683 |
+
size = (size, size)
|
| 684 |
+
|
| 685 |
+
# PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
|
| 686 |
+
if is_torch_tensor(image) or isinstance(image, np.ndarray):
|
| 687 |
+
if image.ndim == 2:
|
| 688 |
+
image = self.expand_dims(image)
|
| 689 |
+
image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
|
| 690 |
+
else:
|
| 691 |
+
image_shape = (image.size[1], image.size[0])
|
| 692 |
+
|
| 693 |
+
top = (image_shape[0] - size[0]) // 2
|
| 694 |
+
bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
|
| 695 |
+
left = (image_shape[1] - size[1]) // 2
|
| 696 |
+
right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
|
| 697 |
+
|
| 698 |
+
# For PIL Images we have a method to crop directly.
|
| 699 |
+
if isinstance(image, PIL.Image.Image):
|
| 700 |
+
return image.crop((left, top, right, bottom))
|
| 701 |
+
|
| 702 |
+
# Check if image is in (n_channels, height, width) or (height, width, n_channels) format
|
| 703 |
+
channel_first = True if image.shape[0] in [1, 3] else False
|
| 704 |
+
|
| 705 |
+
# Transpose (height, width, n_channels) format images
|
| 706 |
+
if not channel_first:
|
| 707 |
+
if isinstance(image, np.ndarray):
|
| 708 |
+
image = image.transpose(2, 0, 1)
|
| 709 |
+
if is_torch_tensor(image):
|
| 710 |
+
image = image.permute(2, 0, 1)
|
| 711 |
+
|
| 712 |
+
# Check if cropped area is within image boundaries
|
| 713 |
+
if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
|
| 714 |
+
return image[..., top:bottom, left:right]
|
| 715 |
+
|
| 716 |
+
# Otherwise, we may need to pad if the image is too small. Oh joy...
|
| 717 |
+
new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
|
| 718 |
+
if isinstance(image, np.ndarray):
|
| 719 |
+
new_image = np.zeros_like(image, shape=new_shape)
|
| 720 |
+
elif is_torch_tensor(image):
|
| 721 |
+
new_image = image.new_zeros(new_shape)
|
| 722 |
+
|
| 723 |
+
top_pad = (new_shape[-2] - image_shape[0]) // 2
|
| 724 |
+
bottom_pad = top_pad + image_shape[0]
|
| 725 |
+
left_pad = (new_shape[-1] - image_shape[1]) // 2
|
| 726 |
+
right_pad = left_pad + image_shape[1]
|
| 727 |
+
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
|
| 728 |
+
|
| 729 |
+
top += top_pad
|
| 730 |
+
bottom += top_pad
|
| 731 |
+
left += left_pad
|
| 732 |
+
right += left_pad
|
| 733 |
+
|
| 734 |
+
new_image = new_image[
|
| 735 |
+
..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
|
| 736 |
+
]
|
| 737 |
+
|
| 738 |
+
return new_image
|
| 739 |
+
|
| 740 |
+
def flip_channel_order(self, image):
|
| 741 |
+
"""
|
| 742 |
+
Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
|
| 743 |
+
`image` to a NumPy array if it's a PIL Image.
|
| 744 |
+
|
| 745 |
+
Args:
|
| 746 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 747 |
+
The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
|
| 748 |
+
be first.
|
| 749 |
+
"""
|
| 750 |
+
self._ensure_format_supported(image)
|
| 751 |
+
|
| 752 |
+
if isinstance(image, PIL.Image.Image):
|
| 753 |
+
image = self.to_numpy_array(image)
|
| 754 |
+
|
| 755 |
+
return image[::-1, :, :]
|
| 756 |
+
|
| 757 |
+
def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
|
| 758 |
+
"""
|
| 759 |
+
Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
|
| 760 |
+
counter clockwise around its centre.
|
| 761 |
+
|
| 762 |
+
Args:
|
| 763 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 764 |
+
The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
|
| 765 |
+
rotating.
|
| 766 |
+
|
| 767 |
+
Returns:
|
| 768 |
+
image: A rotated `PIL.Image.Image`.
|
| 769 |
+
"""
|
| 770 |
+
resample = resample if resample is not None else PIL.Image.NEAREST
|
| 771 |
+
|
| 772 |
+
self._ensure_format_supported(image)
|
| 773 |
+
|
| 774 |
+
if not isinstance(image, PIL.Image.Image):
|
| 775 |
+
image = self.to_pil_image(image)
|
| 776 |
+
|
| 777 |
+
return image.rotate(
|
| 778 |
+
angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def validate_annotations(
|
| 783 |
+
annotation_format: AnnotationFormat,
|
| 784 |
+
supported_annotation_formats: Tuple[AnnotationFormat, ...],
|
| 785 |
+
annotations: List[Dict],
|
| 786 |
+
) -> None:
|
| 787 |
+
if annotation_format not in supported_annotation_formats:
|
| 788 |
+
raise ValueError(f"Unsupported annotation format: {format} must be one of {supported_annotation_formats}")
|
| 789 |
+
|
| 790 |
+
if annotation_format is AnnotationFormat.COCO_DETECTION:
|
| 791 |
+
if not valid_coco_detection_annotations(annotations):
|
| 792 |
+
raise ValueError(
|
| 793 |
+
"Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
|
| 794 |
+
"(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
|
| 795 |
+
"being a list of annotations in the COCO format."
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
if annotation_format is AnnotationFormat.COCO_PANOPTIC:
|
| 799 |
+
if not valid_coco_panoptic_annotations(annotations):
|
| 800 |
+
raise ValueError(
|
| 801 |
+
"Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
|
| 802 |
+
"(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
|
| 803 |
+
"the latter being a list of annotations in the COCO format."
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
|
| 808 |
+
unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
|
| 809 |
+
if unused_keys:
|
| 810 |
+
unused_key_str = ", ".join(unused_keys)
|
| 811 |
+
# TODO raise a warning here instead of simply logging?
|
| 812 |
+
logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
|
modeling_dualvitok.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import math
|
| 6 |
+
from typing import Optional, Tuple, Union, List, Callable
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.nn import Module
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat, pack, unpack
|
| 14 |
+
from einx import get_at
|
| 15 |
+
|
| 16 |
+
from torch.utils.checkpoint import checkpoint
|
| 17 |
+
from transformers import AutoImageProcessor
|
| 18 |
+
from transformers.modeling_utils import PreTrainedModel, get_parameter_device, get_parameter_dtype
|
| 19 |
+
|
| 20 |
+
from .configuration_dualvitok import DualViTokConfig
|
| 21 |
+
from .modeling_movqgan import MoVQModel, MoVQEncoder, MoVQDecoder, Decoder
|
| 22 |
+
|
| 23 |
+
from .configuration_qwen2vit import Qwen2VLVisionConfig
|
| 24 |
+
from .modeling_qwen2vit import Qwen2VisionTransformerPretrainedModel, \
|
| 25 |
+
VisionRotaryEmbedding, Qwen2VLBatchVisionBlock
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
import xformers.ops as xops
|
| 29 |
+
|
| 30 |
+
is_xformers_available = True
|
| 31 |
+
except Exception as e:
|
| 32 |
+
is_xformers_available = False
|
| 33 |
+
|
| 34 |
+
if torch.__version__ > "2.1.2":
|
| 35 |
+
IS_SDPA_AVAILABLE = True
|
| 36 |
+
else:
|
| 37 |
+
IS_SDPA_AVAILABLE = False
|
| 38 |
+
|
| 39 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 40 |
+
sys.path.append(cur_dir)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# helper functions
|
| 44 |
+
|
| 45 |
+
def exists(v):
|
| 46 |
+
return v is not None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def identity(t):
|
| 50 |
+
return t
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def default(v, d):
|
| 54 |
+
return v if exists(v) else d
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def pack_one(t, pattern):
|
| 58 |
+
packed, packed_shape = pack([t], pattern)
|
| 59 |
+
|
| 60 |
+
def inverse(out, inv_pattern=None):
|
| 61 |
+
inv_pattern = default(inv_pattern, pattern)
|
| 62 |
+
out, = unpack(out, packed_shape, inv_pattern)
|
| 63 |
+
return out
|
| 64 |
+
|
| 65 |
+
return packed, inverse
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# class
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SimVQ(Module):
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
dim,
|
| 75 |
+
codebook_size,
|
| 76 |
+
codebook_transform: Module | None = None,
|
| 77 |
+
init_fn: Callable = identity,
|
| 78 |
+
channel_first=True,
|
| 79 |
+
input_to_quantize_commit_loss_weight=0.25,
|
| 80 |
+
commitment_weight=1.,
|
| 81 |
+
frozen_codebook_dim=None # frozen codebook dim could have different dimensions than projection
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.codebook_size = codebook_size
|
| 85 |
+
self.channel_first = channel_first
|
| 86 |
+
|
| 87 |
+
frozen_codebook_dim = default(frozen_codebook_dim, dim)
|
| 88 |
+
codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
|
| 89 |
+
codebook = init_fn(codebook)
|
| 90 |
+
|
| 91 |
+
# the codebook is actually implicit from a linear layer from frozen gaussian or uniform
|
| 92 |
+
|
| 93 |
+
if not exists(codebook_transform):
|
| 94 |
+
codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias=False)
|
| 95 |
+
|
| 96 |
+
self.code_transform = codebook_transform
|
| 97 |
+
|
| 98 |
+
self.register_buffer('frozen_codebook', codebook)
|
| 99 |
+
|
| 100 |
+
# commit loss weighting - weighing input to quantize a bit less is crucial for it to work
|
| 101 |
+
self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
|
| 102 |
+
|
| 103 |
+
# total commitment loss weight
|
| 104 |
+
self.commitment_weight = commitment_weight
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def codebook(self):
|
| 108 |
+
return self.code_transform(self.frozen_codebook)
|
| 109 |
+
|
| 110 |
+
def indices_to_codes(
|
| 111 |
+
self,
|
| 112 |
+
indices
|
| 113 |
+
):
|
| 114 |
+
implicit_codebook = self.codebook
|
| 115 |
+
|
| 116 |
+
frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
|
| 117 |
+
quantized = self.code_transform(frozen_codes)
|
| 118 |
+
|
| 119 |
+
if self.channel_first:
|
| 120 |
+
quantized = rearrange(quantized, 'b ... d -> b d ...')
|
| 121 |
+
|
| 122 |
+
return quantized
|
| 123 |
+
|
| 124 |
+
def forward(
|
| 125 |
+
self,
|
| 126 |
+
x
|
| 127 |
+
):
|
| 128 |
+
if self.channel_first:
|
| 129 |
+
x = rearrange(x, 'b d ... -> b ... d')
|
| 130 |
+
|
| 131 |
+
x, inverse_pack = pack_one(x, 'b * d')
|
| 132 |
+
|
| 133 |
+
implicit_codebook = self.codebook
|
| 134 |
+
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
dist = torch.cdist(x, implicit_codebook)
|
| 137 |
+
indices = dist.argmin(dim=-1)
|
| 138 |
+
|
| 139 |
+
# select codes
|
| 140 |
+
|
| 141 |
+
quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
|
| 142 |
+
|
| 143 |
+
# commit loss and straight through, as was done in the paper
|
| 144 |
+
|
| 145 |
+
commit_loss = (
|
| 146 |
+
F.mse_loss(x.detach(), quantized) +
|
| 147 |
+
F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
quantized = (quantized - x).detach() + x
|
| 151 |
+
|
| 152 |
+
quantized = inverse_pack(quantized)
|
| 153 |
+
indices = inverse_pack(indices, 'b *')
|
| 154 |
+
|
| 155 |
+
if self.channel_first:
|
| 156 |
+
quantized = rearrange(quantized, 'b ... d-> b d ...')
|
| 157 |
+
|
| 158 |
+
return quantized, commit_loss * self.commitment_weight, indices
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def init_weights(m):
|
| 162 |
+
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
|
| 163 |
+
if m.weight is not None:
|
| 164 |
+
nn.init.constant_(m.weight, 1)
|
| 165 |
+
if m.bias is not None:
|
| 166 |
+
nn.init.constant_(m.bias, 0)
|
| 167 |
+
elif isinstance(m, nn.Linear):
|
| 168 |
+
nn.init.xavier_uniform_(m.weight)
|
| 169 |
+
if m.bias is not None:
|
| 170 |
+
nn.init.constant_(m.bias, 0)
|
| 171 |
+
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) \
|
| 172 |
+
or isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
|
| 173 |
+
w = m.weight.data
|
| 174 |
+
nn.init.xavier_uniform_(w)
|
| 175 |
+
if m.bias is not None:
|
| 176 |
+
nn.init.constant_(m.bias, 0)
|
| 177 |
+
elif isinstance(m, nn.Embedding):
|
| 178 |
+
nn.init.normal_(m.weight, mean=0, std=1)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class ScalingLayerForQwen2ViT:
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
min_pixels: int = 56 * 56,
|
| 185 |
+
max_pixels: int = 28 * 28 * 1280,
|
| 186 |
+
patch_size: int = 14,
|
| 187 |
+
temporal_patch_size: int = 2,
|
| 188 |
+
merge_size: int = 2,
|
| 189 |
+
**kwargs,
|
| 190 |
+
) -> None:
|
| 191 |
+
super().__init__(**kwargs)
|
| 192 |
+
OPENAI_CLIP_MEAN = torch.as_tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]
|
| 193 |
+
OPENAI_CLIP_STD = torch.as_tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]
|
| 194 |
+
|
| 195 |
+
self.image_mean = OPENAI_CLIP_MEAN
|
| 196 |
+
self.image_std = OPENAI_CLIP_STD
|
| 197 |
+
self.min_pixels = min_pixels
|
| 198 |
+
self.max_pixels = max_pixels
|
| 199 |
+
self.patch_size = patch_size
|
| 200 |
+
self.temporal_patch_size = temporal_patch_size
|
| 201 |
+
self.merge_size = merge_size
|
| 202 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
| 203 |
+
|
| 204 |
+
def __call__(self, images):
|
| 205 |
+
if images.ndim == 4:
|
| 206 |
+
images = images.unsqueeze(1)
|
| 207 |
+
batch_size, temporal, channel, height, width = images.shape
|
| 208 |
+
|
| 209 |
+
factor = self.patch_size * self.merge_size
|
| 210 |
+
|
| 211 |
+
resized_height, resized_width = height // factor * factor, width // factor * factor
|
| 212 |
+
|
| 213 |
+
images = (images + 1) / 2 # rescale to [0, 1.]
|
| 214 |
+
|
| 215 |
+
images = torch.nn.functional.interpolate(
|
| 216 |
+
images.flatten(0, 1).float(),
|
| 217 |
+
size=(resized_height, resized_width),
|
| 218 |
+
mode='bicubic',
|
| 219 |
+
align_corners=False,
|
| 220 |
+
antialias=True
|
| 221 |
+
).to(images.dtype)
|
| 222 |
+
|
| 223 |
+
images = images.clamp(0, 1) # rescale to [0, 1.]
|
| 224 |
+
images = ((images - self.image_mean.to(images)) / self.image_std.to(images))
|
| 225 |
+
|
| 226 |
+
images = rearrange(images, '(b t) c h w -> b t c h w', b=batch_size, t=temporal)
|
| 227 |
+
if temporal == 1:
|
| 228 |
+
images = images.repeat_interleave(self.temporal_patch_size, dim=1)
|
| 229 |
+
temporal = self.temporal_patch_size
|
| 230 |
+
|
| 231 |
+
grid_t = temporal // self.temporal_patch_size
|
| 232 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
| 233 |
+
|
| 234 |
+
images = images.reshape(
|
| 235 |
+
batch_size * grid_t,
|
| 236 |
+
self.temporal_patch_size,
|
| 237 |
+
channel,
|
| 238 |
+
-1
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
images = rearrange(images, 'b p c n -> b n (c p)')
|
| 242 |
+
images = images.reshape(
|
| 243 |
+
batch_size * grid_t,
|
| 244 |
+
grid_h // self.merge_size,
|
| 245 |
+
self.merge_size,
|
| 246 |
+
self.patch_size,
|
| 247 |
+
grid_w // self.merge_size,
|
| 248 |
+
self.merge_size,
|
| 249 |
+
self.patch_size,
|
| 250 |
+
-1
|
| 251 |
+
)
|
| 252 |
+
images = rearrange(images, 'b h k s1 w l s2 n -> (b h w k l) (n s1 s2)')
|
| 253 |
+
|
| 254 |
+
return dict(image=images, image_grid_thw=torch.as_tensor([[grid_t, grid_h, grid_w] for _ in range(batch_size)]))
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class SemanticEncoder(nn.Module):
|
| 258 |
+
def __init__(self,
|
| 259 |
+
semantic_encoder,
|
| 260 |
+
z_channels=4,
|
| 261 |
+
num_blocks=2,
|
| 262 |
+
embed_dim=1280,
|
| 263 |
+
proj_layer='linear',
|
| 264 |
+
attn_implementation='xformers',
|
| 265 |
+
target_mlp='identity',
|
| 266 |
+
):
|
| 267 |
+
super().__init__()
|
| 268 |
+
self.embed_dim = embed_dim
|
| 269 |
+
|
| 270 |
+
if isinstance(semantic_encoder, str):
|
| 271 |
+
self.model = Qwen2VisionTransformerPretrainedModel.from_pretrained(
|
| 272 |
+
semantic_encoder,
|
| 273 |
+
attn_implementation=attn_implementation
|
| 274 |
+
)
|
| 275 |
+
elif isinstance(semantic_encoder, dict):
|
| 276 |
+
config = Qwen2VLVisionConfig(**semantic_encoder, attn_implementation=attn_implementation)
|
| 277 |
+
self.model = Qwen2VisionTransformerPretrainedModel(config)
|
| 278 |
+
else:
|
| 279 |
+
raise ValueError(f"Invalid semantic_encoder: {semantic_encoder}")
|
| 280 |
+
input_channels = self.model.config.hidden_size
|
| 281 |
+
|
| 282 |
+
for p in self.model.parameters():
|
| 283 |
+
p.requires_grad = False
|
| 284 |
+
|
| 285 |
+
self.proj_in = nn.Conv2d(input_channels, embed_dim, 1, 1) if input_channels != embed_dim else nn.Identity()
|
| 286 |
+
|
| 287 |
+
config = Qwen2VLVisionConfig(depth=num_blocks,
|
| 288 |
+
embed_dim=embed_dim, )
|
| 289 |
+
head_dim = config.embed_dim // config.num_heads
|
| 290 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 291 |
+
|
| 292 |
+
self.blocks = nn.ModuleList(
|
| 293 |
+
[Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
if proj_layer == 'norm_linear':
|
| 297 |
+
self.proj_out = nn.Sequential(
|
| 298 |
+
nn.LayerNorm(embed_dim),
|
| 299 |
+
nn.Linear(
|
| 300 |
+
embed_dim,
|
| 301 |
+
z_channels,
|
| 302 |
+
)
|
| 303 |
+
)
|
| 304 |
+
elif proj_layer == 'linear':
|
| 305 |
+
self.proj_out = nn.Sequential(
|
| 306 |
+
nn.Linear(
|
| 307 |
+
embed_dim,
|
| 308 |
+
z_channels,
|
| 309 |
+
)
|
| 310 |
+
)
|
| 311 |
+
elif proj_layer == 'mlp':
|
| 312 |
+
self.proj_out = nn.Sequential(
|
| 313 |
+
nn.Linear(embed_dim, embed_dim),
|
| 314 |
+
nn.Tanh(),
|
| 315 |
+
nn.Linear(embed_dim, z_channels),
|
| 316 |
+
)
|
| 317 |
+
else:
|
| 318 |
+
raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
|
| 319 |
+
|
| 320 |
+
if target_mlp == 'identity':
|
| 321 |
+
self.target_mlp = nn.Sequential(
|
| 322 |
+
nn.Identity(),
|
| 323 |
+
)
|
| 324 |
+
elif target_mlp == 'norm':
|
| 325 |
+
self.target_mlp = nn.Sequential(
|
| 326 |
+
nn.LayerNorm(input_channels, eps=1e-6, elementwise_affine=False),
|
| 327 |
+
)
|
| 328 |
+
self.init_weight()
|
| 329 |
+
|
| 330 |
+
def init_weight(self):
|
| 331 |
+
self.proj_in.apply(init_weights)
|
| 332 |
+
self.blocks.apply(init_weights)
|
| 333 |
+
self.proj_out.apply(init_weights)
|
| 334 |
+
self.target_mlp.apply(init_weights)
|
| 335 |
+
|
| 336 |
+
def rot_pos_emb(self, grid_thw, max_seq_len):
|
| 337 |
+
pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
|
| 338 |
+
for idx, (t, h, w) in enumerate(grid_thw):
|
| 339 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 340 |
+
hpos_ids = hpos_ids.flatten()
|
| 341 |
+
|
| 342 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 343 |
+
wpos_ids = wpos_ids.flatten()
|
| 344 |
+
|
| 345 |
+
current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
|
| 346 |
+
pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
|
| 347 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 348 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 349 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
|
| 350 |
+
return rotary_pos_emb
|
| 351 |
+
|
| 352 |
+
def forward(self, x, grid_thw):
|
| 353 |
+
x = self.model(x, grid_thw=grid_thw)
|
| 354 |
+
|
| 355 |
+
x = x_target = self.target_mlp(x)
|
| 356 |
+
|
| 357 |
+
x = F.linear(x,
|
| 358 |
+
self.proj_in.weight.view(self.proj_in.weight.shape[0], -1),
|
| 359 |
+
self.proj_in.bias)
|
| 360 |
+
|
| 361 |
+
new_grid_thw = torch.as_tensor([[t, h // 2, w // 2] for t, h, w in grid_thw])
|
| 362 |
+
|
| 363 |
+
seq_lens = [t_i * h_i * w_i for t_i, h_i, w_i in new_grid_thw]
|
| 364 |
+
max_seq_len = max(seq_lens)
|
| 365 |
+
|
| 366 |
+
x = rearrange(x, '(b h w) c -> b (h w) c', h=new_grid_thw[0, 1], w=new_grid_thw[0, 2])
|
| 367 |
+
|
| 368 |
+
rotary_pos_emb = self.rot_pos_emb(new_grid_thw, max_seq_len)
|
| 369 |
+
|
| 370 |
+
for blk in self.blocks:
|
| 371 |
+
x = blk(x, rotary_pos_emb=rotary_pos_emb)
|
| 372 |
+
|
| 373 |
+
x = self.proj_out(x) # [b, max_length, d]
|
| 374 |
+
|
| 375 |
+
t, h, w = new_grid_thw[0]
|
| 376 |
+
b = len(grid_thw)
|
| 377 |
+
x = rearrange(x, 'b (h w) c ->b c h w', b=b, h=h, w=w)
|
| 378 |
+
x_target = rearrange(x_target, '(b h w) c ->b c h w', b=b, h=h, w=w)
|
| 379 |
+
return x, x_target
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class SemanticDecoder(nn.Module):
|
| 383 |
+
def __init__(self,
|
| 384 |
+
z_channels=4,
|
| 385 |
+
embed_dim=1280,
|
| 386 |
+
num_blocks=2,
|
| 387 |
+
output_channels=1280,
|
| 388 |
+
attn_implementation='xformers',
|
| 389 |
+
proj_layer='linear_norm'):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.proj_in = nn.Linear(z_channels, embed_dim)
|
| 392 |
+
|
| 393 |
+
self.output_channels = output_channels
|
| 394 |
+
config = Qwen2VLVisionConfig(depth=num_blocks, embed_dim=embed_dim)
|
| 395 |
+
|
| 396 |
+
self.blocks = nn.ModuleList(
|
| 397 |
+
[Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
|
| 398 |
+
)
|
| 399 |
+
head_dim = config.embed_dim // config.num_heads
|
| 400 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 401 |
+
|
| 402 |
+
if proj_layer == 'norm_linear':
|
| 403 |
+
self.proj_out = nn.Sequential(
|
| 404 |
+
nn.LayerNorm(embed_dim),
|
| 405 |
+
nn.Linear(embed_dim, output_channels),
|
| 406 |
+
)
|
| 407 |
+
elif proj_layer == 'linear':
|
| 408 |
+
self.proj_out = nn.Sequential(
|
| 409 |
+
nn.Linear(embed_dim, output_channels)
|
| 410 |
+
)
|
| 411 |
+
elif proj_layer == 'mlp':
|
| 412 |
+
self.proj_out = nn.Sequential(
|
| 413 |
+
nn.Linear(embed_dim, embed_dim),
|
| 414 |
+
nn.Tanh(),
|
| 415 |
+
nn.Linear(embed_dim, output_channels),
|
| 416 |
+
)
|
| 417 |
+
elif proj_layer == 'linear_norm':
|
| 418 |
+
self.proj_out = nn.Sequential(
|
| 419 |
+
nn.Linear(embed_dim, output_channels),
|
| 420 |
+
nn.LayerNorm(output_channels),
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
self.apply(init_weights)
|
| 424 |
+
|
| 425 |
+
@property
|
| 426 |
+
def last_layer(self):
|
| 427 |
+
return self.proj_out[-1].weight
|
| 428 |
+
|
| 429 |
+
def rot_pos_emb(self, grid_thw, max_seq_len):
|
| 430 |
+
pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
|
| 431 |
+
for idx, (t, h, w) in enumerate(grid_thw):
|
| 432 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 433 |
+
hpos_ids = hpos_ids.flatten()
|
| 434 |
+
|
| 435 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 436 |
+
wpos_ids = wpos_ids.flatten()
|
| 437 |
+
|
| 438 |
+
current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
|
| 439 |
+
pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
|
| 440 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 441 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 442 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
|
| 443 |
+
return rotary_pos_emb
|
| 444 |
+
|
| 445 |
+
def forward(self, z: torch.Tensor):
|
| 446 |
+
x = z
|
| 447 |
+
|
| 448 |
+
b, c, h, w = x.shape
|
| 449 |
+
|
| 450 |
+
x = rearrange(x, 'b c h w -> b (h w) c')
|
| 451 |
+
|
| 452 |
+
grid_thw = torch.as_tensor([[1, h, w] for _ in range(b)])
|
| 453 |
+
seq_lens = [t * h * w for t, h, w in grid_thw]
|
| 454 |
+
max_seq_len = max(seq_lens)
|
| 455 |
+
|
| 456 |
+
x = self.proj_in(x)
|
| 457 |
+
|
| 458 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw, max_seq_len)
|
| 459 |
+
|
| 460 |
+
for blk in self.blocks:
|
| 461 |
+
x = blk(x, rotary_pos_emb=rotary_pos_emb)
|
| 462 |
+
|
| 463 |
+
x = self.proj_out(x)
|
| 464 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
| 465 |
+
return x
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class DualViTokPretrainModel(PreTrainedModel):
|
| 469 |
+
"""
|
| 470 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 471 |
+
models.
|
| 472 |
+
"""
|
| 473 |
+
|
| 474 |
+
config_class = DualViTokConfig
|
| 475 |
+
base_model_prefix = "dualvitok"
|
| 476 |
+
main_input_name = "pixel_values"
|
| 477 |
+
_no_split_modules = ["BatchQwen2VLVisionBlock", "MoVQResnetBlock", "MoVQAttnBlock", "MoVQResnetTemporalBlock"]
|
| 478 |
+
_supports_flash_attn_2 = True
|
| 479 |
+
_supports_sdpa = True
|
| 480 |
+
_supports_cache_class = True
|
| 481 |
+
_supports_static_cache = True
|
| 482 |
+
|
| 483 |
+
def _init_weights(self, module):
|
| 484 |
+
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
|
| 485 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 486 |
+
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
| 487 |
+
elif isinstance(module, nn.Linear):
|
| 488 |
+
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| 489 |
+
if module.bias is not None:
|
| 490 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| 491 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 492 |
+
nn.init.uniform_(module.bias, -bound, bound)
|
| 493 |
+
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
|
| 494 |
+
nn.init.constant_(module.weight, 1)
|
| 495 |
+
nn.init.constant_(module.bias, 0)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class DualViTok(DualViTokPretrainModel):
|
| 499 |
+
def __init__(self, config: DualViTokConfig):
|
| 500 |
+
super().__init__(config)
|
| 501 |
+
self.config = config
|
| 502 |
+
|
| 503 |
+
self._semantic_channel = config.semantic_encoder.z_channels
|
| 504 |
+
self._pixel_channel = config.pixel_encoder.z_channels
|
| 505 |
+
|
| 506 |
+
self.semantic_encoder = SemanticEncoder(
|
| 507 |
+
semantic_encoder=config.semantic_encoder.pretrained_semantic_encoder,
|
| 508 |
+
z_channels=config.semantic_encoder.z_channels,
|
| 509 |
+
num_blocks=config.semantic_encoder.num_blocks,
|
| 510 |
+
embed_dim=config.semantic_encoder.embed_dim,
|
| 511 |
+
proj_layer=config.semantic_encoder.out_layer,
|
| 512 |
+
attn_implementation=config.attn_implementation,
|
| 513 |
+
target_mlp=config.semantic_encoder.target_mlp, )
|
| 514 |
+
self.semantic_decoder = SemanticDecoder(
|
| 515 |
+
z_channels=config.semantic_decoder.z_channels,
|
| 516 |
+
embed_dim=config.semantic_decoder.embed_dim,
|
| 517 |
+
num_blocks=config.semantic_decoder.num_blocks,
|
| 518 |
+
output_channels=config.semantic_decoder.out_channels,
|
| 519 |
+
attn_implementation=config.attn_implementation,
|
| 520 |
+
proj_layer=config.semantic_decoder.out_layer,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
if config.semantic_quantizer_type.lower() == 'simvq':
|
| 524 |
+
self.semantic_quantizer = SimVQ(
|
| 525 |
+
dim=config.semantic_encoder.z_channels,
|
| 526 |
+
codebook_size=config.semantic_quantizer_codebook_size,
|
| 527 |
+
)
|
| 528 |
+
elif config.semantic_quantizer_type.lower() == 'vq':
|
| 529 |
+
raise NotImplementedError
|
| 530 |
+
self.semantic_quantizer = VQ(
|
| 531 |
+
dim=config.semantic_encoder.z_channels,
|
| 532 |
+
codebook_size=config.semantic_quantizer_codebook_size,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
self.pixel_encoder = MoVQEncoder(config.pixel_encoder)
|
| 536 |
+
self.pixel_quant_conv = nn.Conv2d(config.pixel_encoder.z_channels, config.pixel_encoder.embed_dim, 1)
|
| 537 |
+
|
| 538 |
+
if config.pixel_quantizer_type.lower() == 'simvq':
|
| 539 |
+
self.pixel_quantizer = SimVQ(
|
| 540 |
+
dim=config.pixel_encoder.z_channels,
|
| 541 |
+
codebook_size=config.pixel_quantizer_codebook_size,
|
| 542 |
+
)
|
| 543 |
+
elif config.pixel_quantizer_type.lower() == 'vq':
|
| 544 |
+
raise NotImplementedError
|
| 545 |
+
self.pixel_quantizer = VQ(
|
| 546 |
+
dim=config.pixel_encoder.z_channels,
|
| 547 |
+
codebook_size=config.pixel_quantizer_codebook_size,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
self.pixel_post_quant_conv = nn.Conv2d(config.pixel_decoder.embed_dim,
|
| 551 |
+
config.pixel_decoder.z_channels, 1)
|
| 552 |
+
|
| 553 |
+
self.pixel_decoder = MoVQDecoder(config.pixel_decoder)
|
| 554 |
+
|
| 555 |
+
self.scaling_layer = ScalingLayerForQwen2ViT()
|
| 556 |
+
|
| 557 |
+
@property
|
| 558 |
+
def device(self):
|
| 559 |
+
return get_parameter_device(self)
|
| 560 |
+
|
| 561 |
+
@property
|
| 562 |
+
def dtype(self):
|
| 563 |
+
return get_parameter_dtype(self)
|
| 564 |
+
|
| 565 |
+
@property
|
| 566 |
+
def pixel_channel(self):
|
| 567 |
+
return self._pixel_channel
|
| 568 |
+
|
| 569 |
+
@property
|
| 570 |
+
def semantic_channel(self):
|
| 571 |
+
return self._semantic_channel
|
| 572 |
+
|
| 573 |
+
def encode(self, image: torch.FloatTensor):
|
| 574 |
+
scale_output = self.scaling_layer(image)
|
| 575 |
+
image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
|
| 576 |
+
|
| 577 |
+
h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
|
| 578 |
+
quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
|
| 579 |
+
|
| 580 |
+
h_pixel = self.pixel_encoder(image_gen)
|
| 581 |
+
h_pixel = self.pixel_quant_conv(h_pixel)
|
| 582 |
+
|
| 583 |
+
quant_pixel, emb_loss_pixel, info_pixel = self.pixel_quantizer(h_pixel.float())
|
| 584 |
+
|
| 585 |
+
return (quant_semantic, emb_loss_semantic, info_semantic, target_semantic), \
|
| 586 |
+
(quant_pixel, emb_loss_pixel, info_pixel)
|
| 587 |
+
|
| 588 |
+
def encode_code(self, *args, **kwargs):
|
| 589 |
+
(_, _, semantic_indices, _), \
|
| 590 |
+
(_, _, pixel_indices) = self.encode(*args, **kwargs)
|
| 591 |
+
return semantic_indices, pixel_indices
|
| 592 |
+
|
| 593 |
+
def indices_to_codes(self, semantic_indices, pixel_indices):
|
| 594 |
+
quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
|
| 595 |
+
quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
|
| 596 |
+
return quant_semantic, quant_pixel
|
| 597 |
+
|
| 598 |
+
def encode_semantic(self, image: torch.FloatTensor):
|
| 599 |
+
scale_output = self.scaling_layer(image)
|
| 600 |
+
image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
|
| 601 |
+
|
| 602 |
+
h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
|
| 603 |
+
quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
|
| 604 |
+
return quant_semantic, emb_loss_semantic, info_semantic, target_semantic
|
| 605 |
+
|
| 606 |
+
def merge_quants(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
|
| 607 |
+
quant_semantic_resized = F.interpolate(
|
| 608 |
+
quant_semantic, quant_pixel.shape[-2:], mode='bicubic'
|
| 609 |
+
).to(quant_semantic.dtype)
|
| 610 |
+
quant_semantic = quant_semantic_resized
|
| 611 |
+
|
| 612 |
+
quant = torch.cat([quant_semantic, quant_pixel], dim=1)
|
| 613 |
+
|
| 614 |
+
return quant
|
| 615 |
+
|
| 616 |
+
def decode(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor, ):
|
| 617 |
+
quant = self.merge_quants(quant_semantic, quant_pixel)
|
| 618 |
+
quant2 = self.pixel_post_quant_conv(quant)
|
| 619 |
+
x = self.pixel_decoder(quant2, quant)
|
| 620 |
+
return x
|
| 621 |
+
|
| 622 |
+
def decode_code(self, semantic_indices, pixel_indices):
|
| 623 |
+
quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
|
| 624 |
+
quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
|
| 625 |
+
return self.decode(quant_semantic, quant_pixel)
|
| 626 |
+
|
| 627 |
+
def decode_semantic(self, x: List[torch.Tensor]):
|
| 628 |
+
return self.semantic_decoder(x)
|
| 629 |
+
|
| 630 |
+
def forward(self, pixel_values: torch.FloatTensor):
|
| 631 |
+
(quant_semantic, diff_semantic, _, target_semantic), \
|
| 632 |
+
(quant_pixel, diff_pixel, _) = self.encode(pixel_values)
|
| 633 |
+
dec = self.decode(quant_semantic, quant_pixel)
|
| 634 |
+
dec_semantic = self.decode_semantic(quant_semantic)
|
| 635 |
+
return (dec_semantic, diff_semantic, target_semantic), (dec, diff_pixel)
|
| 636 |
+
|
| 637 |
+
def build_sdxl_decoder(self, path='ILLUME-MLLM/dualvitok-sdxl-decoder',
|
| 638 |
+
image_processor=None,
|
| 639 |
+
torch_dtype=torch.float16,
|
| 640 |
+
add_watermarker=False,
|
| 641 |
+
device='cuda',
|
| 642 |
+
):
|
| 643 |
+
from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
|
| 644 |
+
|
| 645 |
+
if image_processor is None:
|
| 646 |
+
image_processor = AutoImageProcessor.from_pretrained('ILLUME-MLLM/dualvitok', trust_remote_code=True)
|
| 647 |
+
|
| 648 |
+
return StableDiffusionXLDecoderPipeline.from_pretrained(path,
|
| 649 |
+
torch_dtype=torch_dtype,
|
| 650 |
+
add_watermarker=add_watermarker,
|
| 651 |
+
vq_model=self,
|
| 652 |
+
vq_image_processor=image_processor).to(device)
|
| 653 |
+
|
modeling_movqgan.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" MoVQ model """
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.utils.checkpoint import checkpoint
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
|
| 13 |
+
from .configuration_movqgan import MoVQConfig
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import xformers.ops as xops
|
| 17 |
+
|
| 18 |
+
is_xformers_available = True
|
| 19 |
+
except Exception as e:
|
| 20 |
+
is_xformers_available = False
|
| 21 |
+
|
| 22 |
+
if torch.__version__ > "2.1.2":
|
| 23 |
+
IS_SDPA_AVAILABLE = True
|
| 24 |
+
else:
|
| 25 |
+
IS_SDPA_AVAILABLE = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MoVQActivation(nn.Module):
|
| 29 |
+
|
| 30 |
+
def __init__(self):
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
def __call__(self, x: torch.Tensor):
|
| 34 |
+
return x * torch.sigmoid(x)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MoVQUpsample(nn.Module):
|
| 38 |
+
|
| 39 |
+
def __init__(self, in_channels: int):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.conv = nn.Conv2d(
|
| 42 |
+
in_channels,
|
| 43 |
+
in_channels,
|
| 44 |
+
kernel_size=3,
|
| 45 |
+
stride=1,
|
| 46 |
+
padding=1,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def forward(self, x: torch.Tensor):
|
| 50 |
+
x = F.interpolate(x.float(), scale_factor=2.0, mode="nearest").to(x.dtype)
|
| 51 |
+
x = self.conv(x)
|
| 52 |
+
return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class DCDownBlock2d(nn.Module):
|
| 56 |
+
def __init__(self, in_channels: int, out_channels: int = None, downsample: bool = True,
|
| 57 |
+
shortcut: bool = True) -> None:
|
| 58 |
+
super().__init__()
|
| 59 |
+
out_channels = out_channels if out_channels else in_channels
|
| 60 |
+
|
| 61 |
+
self.downsample = downsample
|
| 62 |
+
self.factor = 2
|
| 63 |
+
self.stride = 1 if downsample else 2
|
| 64 |
+
self.group_size = in_channels * self.factor ** 2 // out_channels
|
| 65 |
+
self.shortcut = shortcut
|
| 66 |
+
|
| 67 |
+
out_ratio = self.factor ** 2
|
| 68 |
+
if downsample:
|
| 69 |
+
assert out_channels % out_ratio == 0
|
| 70 |
+
out_channels = out_channels // out_ratio
|
| 71 |
+
|
| 72 |
+
self.conv = nn.Conv2d(
|
| 73 |
+
in_channels,
|
| 74 |
+
out_channels,
|
| 75 |
+
kernel_size=3,
|
| 76 |
+
stride=self.stride,
|
| 77 |
+
padding=1,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
x = self.conv(hidden_states)
|
| 82 |
+
if self.downsample:
|
| 83 |
+
x = F.pixel_unshuffle(x, self.factor)
|
| 84 |
+
|
| 85 |
+
if self.shortcut:
|
| 86 |
+
y = F.pixel_unshuffle(hidden_states, self.factor)
|
| 87 |
+
y = y.unflatten(1, (-1, self.group_size))
|
| 88 |
+
y = y.mean(dim=2)
|
| 89 |
+
hidden_states = x + y
|
| 90 |
+
else:
|
| 91 |
+
hidden_states = x
|
| 92 |
+
|
| 93 |
+
return hidden_states # x + y
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class DCUpBlock2d(nn.Module):
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
in_channels: int,
|
| 100 |
+
out_channels: int = None,
|
| 101 |
+
interpolate: bool = False,
|
| 102 |
+
shortcut: bool = True,
|
| 103 |
+
interpolation_mode: str = "nearest",
|
| 104 |
+
) -> None:
|
| 105 |
+
super().__init__()
|
| 106 |
+
out_channels = out_channels if out_channels else in_channels
|
| 107 |
+
|
| 108 |
+
self.interpolate = interpolate
|
| 109 |
+
self.interpolation_mode = interpolation_mode
|
| 110 |
+
self.shortcut = shortcut
|
| 111 |
+
self.factor = 2
|
| 112 |
+
self.repeats = out_channels * self.factor ** 2 // in_channels
|
| 113 |
+
|
| 114 |
+
out_ratio = self.factor ** 2
|
| 115 |
+
|
| 116 |
+
if not interpolate:
|
| 117 |
+
out_channels = out_channels * out_ratio
|
| 118 |
+
|
| 119 |
+
self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
| 120 |
+
|
| 121 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
if self.interpolate:
|
| 123 |
+
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
| 124 |
+
x = self.conv(x)
|
| 125 |
+
else:
|
| 126 |
+
x = self.conv(hidden_states)
|
| 127 |
+
x = F.pixel_shuffle(x, self.factor)
|
| 128 |
+
|
| 129 |
+
if self.shortcut:
|
| 130 |
+
y = hidden_states.repeat_interleave(self.repeats, dim=1)
|
| 131 |
+
y = F.pixel_shuffle(y, self.factor)
|
| 132 |
+
hidden_states = x + y
|
| 133 |
+
else:
|
| 134 |
+
hidden_states = x
|
| 135 |
+
|
| 136 |
+
return hidden_states
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class MoVQDownsample(nn.Module):
|
| 140 |
+
|
| 141 |
+
def __init__(self, in_channels: int):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.conv = nn.Conv2d(
|
| 144 |
+
in_channels,
|
| 145 |
+
in_channels,
|
| 146 |
+
kernel_size=3,
|
| 147 |
+
stride=2,
|
| 148 |
+
padding=0,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def forward(self, x: torch.Tensor):
|
| 152 |
+
pad = (0, 1, 0, 1)
|
| 153 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
| 154 |
+
x = self.conv(x)
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class MoVQSpatialNorm(nn.Module):
|
| 159 |
+
|
| 160 |
+
def __init__(
|
| 161 |
+
self,
|
| 162 |
+
f_channels: int,
|
| 163 |
+
zq_channels: int,
|
| 164 |
+
norm_layer: nn.Module = nn.GroupNorm,
|
| 165 |
+
add_conv: bool = False,
|
| 166 |
+
num_groups: int = 32,
|
| 167 |
+
eps: float = 1e-6,
|
| 168 |
+
affine: bool = True,
|
| 169 |
+
):
|
| 170 |
+
super().__init__()
|
| 171 |
+
self.norm_layer = norm_layer(
|
| 172 |
+
num_channels=f_channels,
|
| 173 |
+
num_groups=num_groups,
|
| 174 |
+
eps=eps,
|
| 175 |
+
affine=affine,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
self.add_conv = add_conv
|
| 179 |
+
if self.add_conv:
|
| 180 |
+
self.conv = nn.Conv2d(
|
| 181 |
+
zq_channels,
|
| 182 |
+
zq_channels,
|
| 183 |
+
kernel_size=3,
|
| 184 |
+
stride=1,
|
| 185 |
+
padding=1,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.conv_y = nn.Conv2d(
|
| 189 |
+
zq_channels,
|
| 190 |
+
f_channels,
|
| 191 |
+
kernel_size=1,
|
| 192 |
+
stride=1,
|
| 193 |
+
padding=0,
|
| 194 |
+
)
|
| 195 |
+
self.conv_b = nn.Conv2d(
|
| 196 |
+
zq_channels,
|
| 197 |
+
f_channels,
|
| 198 |
+
kernel_size=1,
|
| 199 |
+
stride=1,
|
| 200 |
+
padding=0,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
def forward(self, x: torch.Tensor, zq: torch.Tensor):
|
| 204 |
+
zq = F.interpolate(zq.float(), size=x.shape[-2:], mode="nearest").to(zq.dtype)
|
| 205 |
+
|
| 206 |
+
if self.add_conv:
|
| 207 |
+
zq = self.conv(zq)
|
| 208 |
+
|
| 209 |
+
x = self.norm_layer(x)
|
| 210 |
+
x = x * self.conv_y(zq) + self.conv_b(zq)
|
| 211 |
+
return x
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class MoVQResnetBlock(nn.Module):
|
| 215 |
+
|
| 216 |
+
def __init__(
|
| 217 |
+
self,
|
| 218 |
+
in_channels: int,
|
| 219 |
+
out_channels: Optional[int] = None,
|
| 220 |
+
conv_shortcut: bool = False,
|
| 221 |
+
dropout: float = 0.0,
|
| 222 |
+
zq_ch: Optional[int] = None,
|
| 223 |
+
add_conv: bool = False,
|
| 224 |
+
):
|
| 225 |
+
super().__init__()
|
| 226 |
+
self.in_channels = in_channels
|
| 227 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 228 |
+
self.out_channels = out_channels
|
| 229 |
+
self.use_conv_shortcut = conv_shortcut
|
| 230 |
+
self.zq_ch = zq_ch
|
| 231 |
+
|
| 232 |
+
if zq_ch is None:
|
| 233 |
+
norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
|
| 234 |
+
self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
|
| 235 |
+
self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
|
| 236 |
+
else:
|
| 237 |
+
self.norm1 = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
|
| 238 |
+
self.norm2 = MoVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv)
|
| 239 |
+
|
| 240 |
+
self.conv1 = nn.Conv2d(
|
| 241 |
+
in_channels,
|
| 242 |
+
out_channels,
|
| 243 |
+
kernel_size=3,
|
| 244 |
+
stride=1,
|
| 245 |
+
padding=1,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
self.dropout = nn.Dropout(dropout)
|
| 249 |
+
self.conv2 = nn.Conv2d(
|
| 250 |
+
out_channels,
|
| 251 |
+
out_channels,
|
| 252 |
+
kernel_size=3,
|
| 253 |
+
stride=1,
|
| 254 |
+
padding=1,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
self.act = MoVQActivation()
|
| 258 |
+
|
| 259 |
+
if self.in_channels != self.out_channels:
|
| 260 |
+
if self.use_conv_shortcut:
|
| 261 |
+
self.conv_shortcut = nn.Conv2d(
|
| 262 |
+
in_channels,
|
| 263 |
+
out_channels,
|
| 264 |
+
kernel_size=3,
|
| 265 |
+
stride=1,
|
| 266 |
+
padding=1,
|
| 267 |
+
)
|
| 268 |
+
else:
|
| 269 |
+
self.nin_shortcut = nn.Conv2d(
|
| 270 |
+
in_channels,
|
| 271 |
+
out_channels,
|
| 272 |
+
kernel_size=1,
|
| 273 |
+
stride=1,
|
| 274 |
+
padding=0,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
|
| 278 |
+
norm_args = tuple() if self.zq_ch is None else (zq,)
|
| 279 |
+
|
| 280 |
+
h = self.norm1(x, *norm_args)
|
| 281 |
+
h = self.act(h)
|
| 282 |
+
h = self.conv1(h)
|
| 283 |
+
|
| 284 |
+
h = self.norm2(h, *norm_args)
|
| 285 |
+
h = self.act(h)
|
| 286 |
+
h = self.dropout(h)
|
| 287 |
+
h = self.conv2(h)
|
| 288 |
+
|
| 289 |
+
if self.in_channels != self.out_channels:
|
| 290 |
+
if self.use_conv_shortcut:
|
| 291 |
+
x = self.conv_shortcut(x)
|
| 292 |
+
else:
|
| 293 |
+
x = self.nin_shortcut(x)
|
| 294 |
+
|
| 295 |
+
return x + h
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class MoVQAttnBlock(nn.Module):
|
| 299 |
+
|
| 300 |
+
def __init__(
|
| 301 |
+
self,
|
| 302 |
+
in_channels: int,
|
| 303 |
+
zq_ch: Optional[int] = None,
|
| 304 |
+
add_conv: bool = False,
|
| 305 |
+
num_heads=1,
|
| 306 |
+
):
|
| 307 |
+
super().__init__()
|
| 308 |
+
self.in_channels = in_channels
|
| 309 |
+
self.zq_ch = zq_ch
|
| 310 |
+
self.num_heads = num_heads
|
| 311 |
+
|
| 312 |
+
if zq_ch is None:
|
| 313 |
+
norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
|
| 314 |
+
self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
|
| 315 |
+
else:
|
| 316 |
+
self.norm = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
|
| 317 |
+
|
| 318 |
+
self.q = nn.Conv2d(
|
| 319 |
+
in_channels,
|
| 320 |
+
in_channels,
|
| 321 |
+
kernel_size=1,
|
| 322 |
+
stride=1,
|
| 323 |
+
padding=0,
|
| 324 |
+
)
|
| 325 |
+
self.k = nn.Conv2d(
|
| 326 |
+
in_channels,
|
| 327 |
+
in_channels,
|
| 328 |
+
kernel_size=1,
|
| 329 |
+
stride=1,
|
| 330 |
+
padding=0,
|
| 331 |
+
)
|
| 332 |
+
self.v = nn.Conv2d(
|
| 333 |
+
in_channels,
|
| 334 |
+
in_channels,
|
| 335 |
+
kernel_size=1,
|
| 336 |
+
stride=1,
|
| 337 |
+
padding=0,
|
| 338 |
+
)
|
| 339 |
+
self.proj_out = nn.Conv2d(
|
| 340 |
+
in_channels,
|
| 341 |
+
in_channels,
|
| 342 |
+
kernel_size=1,
|
| 343 |
+
stride=1,
|
| 344 |
+
padding=0,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
|
| 348 |
+
# x: [b, c1, h1, w1]
|
| 349 |
+
# zq: [b, c2, h2, w2]
|
| 350 |
+
# attention_mask: [b, 1, h3, w3]
|
| 351 |
+
norm_args = tuple() if self.zq_ch is None else (zq,)
|
| 352 |
+
|
| 353 |
+
# if context is not None:
|
| 354 |
+
# context = F.interpolate(context.float(), size=x.shape[-2:], mode="nearest").to(context.dtype)
|
| 355 |
+
# x = x + self.conv_context(context)
|
| 356 |
+
|
| 357 |
+
nx = self.norm(x, *norm_args)
|
| 358 |
+
q = self.q(nx)
|
| 359 |
+
k = self.k(nx)
|
| 360 |
+
v = self.v(nx)
|
| 361 |
+
|
| 362 |
+
b, c, h, w = q.shape
|
| 363 |
+
if is_xformers_available:
|
| 364 |
+
# If xformers is available, create attn_bias for xops.memory_efficient_attention.
|
| 365 |
+
attn_bias = None
|
| 366 |
+
|
| 367 |
+
v = xops.memory_efficient_attention(
|
| 368 |
+
rearrange(q, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
|
| 369 |
+
rearrange(k, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
|
| 370 |
+
rearrange(v, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
|
| 371 |
+
scale=1.0 / math.sqrt(c // self.num_heads),
|
| 372 |
+
attn_bias=attn_bias,
|
| 373 |
+
)
|
| 374 |
+
v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w).contiguous()
|
| 375 |
+
elif IS_SDPA_AVAILABLE:
|
| 376 |
+
# compute attention
|
| 377 |
+
q = rearrange(q, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
|
| 378 |
+
k = rearrange(k, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
|
| 379 |
+
v = rearrange(v, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
|
| 380 |
+
|
| 381 |
+
attn_bias = None
|
| 382 |
+
|
| 383 |
+
v = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
|
| 384 |
+
v = v.transpose(1, 2)
|
| 385 |
+
v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w)
|
| 386 |
+
else:
|
| 387 |
+
# compute attention
|
| 388 |
+
q = rearrange(q, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
|
| 389 |
+
k = rearrange(k, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
|
| 390 |
+
v = rearrange(v, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
|
| 391 |
+
|
| 392 |
+
# score = torch.bmm(q.permute(0, 2, 1), k)
|
| 393 |
+
score = torch.einsum('b n c k, b n c l -> b n k l', q, k)
|
| 394 |
+
score = score / math.sqrt(c // self.num_heads)
|
| 395 |
+
|
| 396 |
+
score = F.softmax(score, dim=2)
|
| 397 |
+
|
| 398 |
+
# attend to values
|
| 399 |
+
# v = v.reshape(b, c, h * w)
|
| 400 |
+
# v = torch.bmm(v, score.permute(0, 2, 1))
|
| 401 |
+
v = torch.einsum('b n c l, b n k l -> b n c k', v, score)
|
| 402 |
+
v = v.reshape(b, c, h, w)
|
| 403 |
+
|
| 404 |
+
v = self.proj_out(v)
|
| 405 |
+
|
| 406 |
+
return x + v
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class MoVQVectorQuantizer(nn.Module):
|
| 410 |
+
|
| 411 |
+
def __init__(self, config: MoVQConfig):
|
| 412 |
+
super().__init__()
|
| 413 |
+
self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
|
| 414 |
+
self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
|
| 415 |
+
|
| 416 |
+
def forward(self, x: torch.Tensor):
|
| 417 |
+
# b t c h w -> b t h w c
|
| 418 |
+
b, t, c, h, w = x.shape
|
| 419 |
+
x = x.permute(0, 1, 3, 4, 2).contiguous()
|
| 420 |
+
x_flattened = x.view(-1, c)
|
| 421 |
+
|
| 422 |
+
codebook = self.embedding.weight
|
| 423 |
+
|
| 424 |
+
d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \
|
| 425 |
+
torch.sum(codebook ** 2, dim=1) - 2 * \
|
| 426 |
+
torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0))
|
| 427 |
+
|
| 428 |
+
indices = torch.argmin(d, dim=1)
|
| 429 |
+
indices = indices.view(b, t, h, w)
|
| 430 |
+
return indices
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class MoVQPretrainedModel(PreTrainedModel):
|
| 434 |
+
"""
|
| 435 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 436 |
+
models.
|
| 437 |
+
"""
|
| 438 |
+
|
| 439 |
+
config_class = MoVQConfig
|
| 440 |
+
base_model_prefix = "movq"
|
| 441 |
+
main_input_name = "pixel_values"
|
| 442 |
+
_no_split_modules = ["MoVQResnetBlock", "MoVQAttnBlock"]
|
| 443 |
+
|
| 444 |
+
def _init_weights(self, module):
|
| 445 |
+
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
|
| 446 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 447 |
+
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
| 448 |
+
elif isinstance(module, nn.Linear):
|
| 449 |
+
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| 450 |
+
if module.bias is not None:
|
| 451 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| 452 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 453 |
+
nn.init.uniform_(module.bias, -bound, bound)
|
| 454 |
+
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
|
| 455 |
+
nn.init.constant_(module.weight, 1)
|
| 456 |
+
nn.init.constant_(module.bias, 0)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class MoVQEncoder(nn.Module):
|
| 460 |
+
def __init__(self, config: MoVQConfig):
|
| 461 |
+
super().__init__()
|
| 462 |
+
self.config = config
|
| 463 |
+
self.ch = config.ch
|
| 464 |
+
self.num_resolutions = len(config.ch_mult)
|
| 465 |
+
self.num_res_blocks = config.num_res_blocks
|
| 466 |
+
self.in_channels = config.in_channels
|
| 467 |
+
|
| 468 |
+
# downsampling
|
| 469 |
+
self.conv_in = nn.Conv2d(
|
| 470 |
+
self.in_channels,
|
| 471 |
+
self.ch,
|
| 472 |
+
kernel_size=3,
|
| 473 |
+
stride=1,
|
| 474 |
+
padding=1
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
in_ch_mult = (1,) + tuple(config.ch_mult)
|
| 478 |
+
self.down = nn.ModuleList()
|
| 479 |
+
for i_level in range(self.num_resolutions):
|
| 480 |
+
block = nn.ModuleList()
|
| 481 |
+
attn = nn.ModuleList()
|
| 482 |
+
block_in = config.ch * in_ch_mult[i_level]
|
| 483 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 484 |
+
for i_block in range(self.num_res_blocks):
|
| 485 |
+
block.append(
|
| 486 |
+
MoVQResnetBlock(
|
| 487 |
+
in_channels=block_in,
|
| 488 |
+
out_channels=block_out,
|
| 489 |
+
dropout=config.dropout,
|
| 490 |
+
)
|
| 491 |
+
)
|
| 492 |
+
block_in = block_out
|
| 493 |
+
if i_level in config.attn_resolutions:
|
| 494 |
+
attn.append(MoVQAttnBlock(block_in))
|
| 495 |
+
|
| 496 |
+
down = nn.Module()
|
| 497 |
+
down.block = block
|
| 498 |
+
down.attn = attn
|
| 499 |
+
if i_level != self.num_resolutions - 1:
|
| 500 |
+
if config.use_dc_up_down_blocks:
|
| 501 |
+
down.downsample = DCDownBlock2d(block_in)
|
| 502 |
+
else:
|
| 503 |
+
down.downsample = MoVQDownsample(block_in)
|
| 504 |
+
|
| 505 |
+
self.down.append(down)
|
| 506 |
+
|
| 507 |
+
# middle
|
| 508 |
+
self.mid = nn.Module()
|
| 509 |
+
self.mid.block_1 = MoVQResnetBlock(
|
| 510 |
+
in_channels=block_in,
|
| 511 |
+
out_channels=block_in,
|
| 512 |
+
dropout=config.dropout,
|
| 513 |
+
)
|
| 514 |
+
self.mid.attn_1 = MoVQAttnBlock(block_in)
|
| 515 |
+
self.mid.block_2 = MoVQResnetBlock(
|
| 516 |
+
in_channels=block_in,
|
| 517 |
+
out_channels=block_in,
|
| 518 |
+
dropout=config.dropout,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# end
|
| 522 |
+
|
| 523 |
+
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
|
| 524 |
+
|
| 525 |
+
self.act = MoVQActivation()
|
| 526 |
+
|
| 527 |
+
out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
| 528 |
+
self.conv_out = nn.Conv2d(
|
| 529 |
+
block_in,
|
| 530 |
+
out_z_channels,
|
| 531 |
+
kernel_size=3,
|
| 532 |
+
stride=1,
|
| 533 |
+
padding=1,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
self.out_shortcut_average_group_size = block_in // out_z_channels
|
| 537 |
+
|
| 538 |
+
def forward(self, x: torch.Tensor):
|
| 539 |
+
|
| 540 |
+
# downsampling
|
| 541 |
+
h = self.conv_in(x)
|
| 542 |
+
for i_level in range(self.num_resolutions):
|
| 543 |
+
for i_block in range(self.num_res_blocks):
|
| 544 |
+
h = self.down[i_level].block[i_block](h)
|
| 545 |
+
if len(self.down[i_level].attn) > 0:
|
| 546 |
+
h = self.down[i_level].attn[i_block](h)
|
| 547 |
+
|
| 548 |
+
if i_level != self.num_resolutions - 1:
|
| 549 |
+
h = self.down[i_level].downsample(h)
|
| 550 |
+
|
| 551 |
+
h = self.mid.block_1(h)
|
| 552 |
+
h = self.mid.attn_1(h)
|
| 553 |
+
h = self.mid.block_2(h)
|
| 554 |
+
|
| 555 |
+
# end
|
| 556 |
+
h = self.norm_out(h)
|
| 557 |
+
h = self.act(h)
|
| 558 |
+
|
| 559 |
+
if self.config.use_dc_up_down_blocks:
|
| 560 |
+
x = h.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
| 561 |
+
x = x.mean(dim=2)
|
| 562 |
+
h = self.conv_out(h) + x
|
| 563 |
+
else:
|
| 564 |
+
h = self.conv_out(h)
|
| 565 |
+
return h
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
class MoVQDecoder(nn.Module):
|
| 569 |
+
def __init__(self, config: MoVQConfig):
|
| 570 |
+
super().__init__()
|
| 571 |
+
self.config = config
|
| 572 |
+
self.ch = config.ch
|
| 573 |
+
self.num_resolutions = len(config.ch_mult)
|
| 574 |
+
self.num_res_blocks = config.num_res_blocks
|
| 575 |
+
|
| 576 |
+
in_ch_mult = (1,) + tuple(config.ch_mult)
|
| 577 |
+
zq_ch = config.embed_dim
|
| 578 |
+
|
| 579 |
+
block_in = config.ch * config.ch_mult[-1]
|
| 580 |
+
|
| 581 |
+
self.in_shortcut_repeats = block_in // config.embed_dim
|
| 582 |
+
|
| 583 |
+
self.conv_in = nn.Conv2d(
|
| 584 |
+
config.z_channels,
|
| 585 |
+
block_in,
|
| 586 |
+
kernel_size=3,
|
| 587 |
+
stride=1,
|
| 588 |
+
padding=1,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
# middle
|
| 592 |
+
self.mid = nn.Module()
|
| 593 |
+
self.mid.block_1 = MoVQResnetBlock(
|
| 594 |
+
in_channels=block_in,
|
| 595 |
+
out_channels=block_in,
|
| 596 |
+
dropout=config.dropout,
|
| 597 |
+
zq_ch=zq_ch,
|
| 598 |
+
)
|
| 599 |
+
self.mid.attn_1 = MoVQAttnBlock(block_in, zq_ch)
|
| 600 |
+
self.mid.block_2 = MoVQResnetBlock(
|
| 601 |
+
in_channels=block_in,
|
| 602 |
+
out_channels=block_in,
|
| 603 |
+
dropout=config.dropout,
|
| 604 |
+
zq_ch=zq_ch,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
# upsampling
|
| 608 |
+
self.up = nn.ModuleList()
|
| 609 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 610 |
+
block = nn.ModuleList()
|
| 611 |
+
attn = nn.ModuleList()
|
| 612 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 613 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 614 |
+
block.append(
|
| 615 |
+
MoVQResnetBlock(
|
| 616 |
+
in_channels=block_in,
|
| 617 |
+
out_channels=block_out,
|
| 618 |
+
dropout=config.dropout,
|
| 619 |
+
zq_ch=zq_ch,
|
| 620 |
+
)
|
| 621 |
+
)
|
| 622 |
+
block_in = block_out
|
| 623 |
+
if i_level in config.attn_resolutions:
|
| 624 |
+
attn.append(MoVQAttnBlock(block_in, zq_ch))
|
| 625 |
+
|
| 626 |
+
up = nn.Module()
|
| 627 |
+
up.block = block
|
| 628 |
+
up.attn = attn
|
| 629 |
+
if i_level != 0:
|
| 630 |
+
if config.use_dc_up_down_blocks:
|
| 631 |
+
up.upsample = DCUpBlock2d(block_in)
|
| 632 |
+
else:
|
| 633 |
+
up.upsample = MoVQUpsample(block_in)
|
| 634 |
+
|
| 635 |
+
self.up.insert(0, up)
|
| 636 |
+
|
| 637 |
+
self.act = MoVQActivation()
|
| 638 |
+
|
| 639 |
+
self.norm_out = MoVQSpatialNorm(block_in, zq_ch)
|
| 640 |
+
self.conv_out = nn.Conv2d(
|
| 641 |
+
block_in,
|
| 642 |
+
config.out_channels,
|
| 643 |
+
kernel_size=3,
|
| 644 |
+
stride=1,
|
| 645 |
+
padding=1,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
@property
|
| 649 |
+
def last_layer(self):
|
| 650 |
+
return self.conv_out.weight
|
| 651 |
+
|
| 652 |
+
def forward(self, z: torch.Tensor, zq: torch.Tensor):
|
| 653 |
+
h = z
|
| 654 |
+
|
| 655 |
+
if self.config.use_dc_up_down_blocks:
|
| 656 |
+
h = h.repeat_interleave(self.in_shortcut_repeats, dim=1)
|
| 657 |
+
h = self.conv_in(z) + h
|
| 658 |
+
else:
|
| 659 |
+
h = self.conv_in(h)
|
| 660 |
+
|
| 661 |
+
# middle
|
| 662 |
+
h = self.mid.block_1(h, zq)
|
| 663 |
+
h = self.mid.attn_1(h, zq)
|
| 664 |
+
h = self.mid.block_2(h, zq)
|
| 665 |
+
|
| 666 |
+
# upsampling
|
| 667 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 668 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 669 |
+
h = self.up[i_level].block[i_block](h, zq)
|
| 670 |
+
if len(self.up[i_level].attn) > 0:
|
| 671 |
+
h = self.up[i_level].attn[i_block](h, zq)
|
| 672 |
+
|
| 673 |
+
if i_level != 0:
|
| 674 |
+
h = self.up[i_level].upsample(h)
|
| 675 |
+
|
| 676 |
+
h = self.norm_out(h, zq)
|
| 677 |
+
h = self.act(h)
|
| 678 |
+
h = self.conv_out(h)
|
| 679 |
+
|
| 680 |
+
return h
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
class Decoder(nn.Module):
|
| 684 |
+
def __init__(self, config: MoVQConfig):
|
| 685 |
+
super().__init__()
|
| 686 |
+
self.config = config
|
| 687 |
+
self.ch = config.ch
|
| 688 |
+
self.num_resolutions = len(config.ch_mult)
|
| 689 |
+
self.num_res_blocks = config.num_res_blocks
|
| 690 |
+
|
| 691 |
+
in_ch_mult = (1,) + tuple(config.ch_mult)
|
| 692 |
+
|
| 693 |
+
block_in = config.ch * config.ch_mult[-1]
|
| 694 |
+
|
| 695 |
+
self.conv_in = nn.Conv2d(
|
| 696 |
+
config.z_channels,
|
| 697 |
+
block_in,
|
| 698 |
+
kernel_size=3,
|
| 699 |
+
stride=1,
|
| 700 |
+
padding=1,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
# middle
|
| 704 |
+
self.mid = nn.Module()
|
| 705 |
+
self.mid.block_1 = MoVQResnetBlock(
|
| 706 |
+
in_channels=block_in,
|
| 707 |
+
out_channels=block_in,
|
| 708 |
+
dropout=config.dropout,
|
| 709 |
+
)
|
| 710 |
+
self.mid.attn_1 = MoVQAttnBlock(block_in)
|
| 711 |
+
self.mid.block_2 = MoVQResnetBlock(
|
| 712 |
+
in_channels=block_in,
|
| 713 |
+
out_channels=block_in,
|
| 714 |
+
dropout=config.dropout,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
# upsampling
|
| 718 |
+
self.up = nn.ModuleList()
|
| 719 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 720 |
+
block = nn.ModuleList()
|
| 721 |
+
attn = nn.ModuleList()
|
| 722 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 723 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 724 |
+
block.append(
|
| 725 |
+
MoVQResnetBlock(
|
| 726 |
+
in_channels=block_in,
|
| 727 |
+
out_channels=block_out,
|
| 728 |
+
dropout=config.dropout,
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
block_in = block_out
|
| 732 |
+
if i_level in config.attn_resolutions:
|
| 733 |
+
attn.append(MoVQAttnBlock(block_in))
|
| 734 |
+
|
| 735 |
+
up = nn.Module()
|
| 736 |
+
up.block = block
|
| 737 |
+
up.attn = attn
|
| 738 |
+
if i_level != 0:
|
| 739 |
+
up.upsample = MoVQUpsample(block_in)
|
| 740 |
+
|
| 741 |
+
self.up.insert(0, up)
|
| 742 |
+
|
| 743 |
+
self.act = MoVQActivation()
|
| 744 |
+
|
| 745 |
+
norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
|
| 746 |
+
self.norm_out = nn.GroupNorm(num_channels=block_in, **norm_kwargs)
|
| 747 |
+
self.conv_out = nn.Conv2d(
|
| 748 |
+
block_in,
|
| 749 |
+
config.out_channels,
|
| 750 |
+
kernel_size=3,
|
| 751 |
+
stride=1,
|
| 752 |
+
padding=1,
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
@property
|
| 756 |
+
def last_layer(self):
|
| 757 |
+
return self.conv_out.weight
|
| 758 |
+
|
| 759 |
+
def forward(self, z: torch.Tensor, zq: torch.Tensor):
|
| 760 |
+
h = z
|
| 761 |
+
h = self.conv_in(h)
|
| 762 |
+
|
| 763 |
+
# middle
|
| 764 |
+
h = self.mid.block_1(h)
|
| 765 |
+
h = self.mid.attn_1(h)
|
| 766 |
+
h = self.mid.block_2(h)
|
| 767 |
+
|
| 768 |
+
# upsampling
|
| 769 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 770 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 771 |
+
h = self.up[i_level].block[i_block](h)
|
| 772 |
+
if len(self.up[i_level].attn) > 0:
|
| 773 |
+
h = self.up[i_level].attn[i_block](h)
|
| 774 |
+
|
| 775 |
+
if i_level != 0:
|
| 776 |
+
h = self.up[i_level].upsample(h)
|
| 777 |
+
|
| 778 |
+
h = self.norm_out(h)
|
| 779 |
+
h = self.act(h)
|
| 780 |
+
h = self.conv_out(h)
|
| 781 |
+
|
| 782 |
+
return h
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
class MoVQModel(MoVQPretrainedModel):
|
| 786 |
+
|
| 787 |
+
def __init__(self, config):
|
| 788 |
+
super().__init__(config)
|
| 789 |
+
self.config = config
|
| 790 |
+
|
| 791 |
+
self.encoder = MoVQEncoder(config)
|
| 792 |
+
self.decoder = MoVQDecoder(config)
|
| 793 |
+
self.quantize = MoVQVectorQuantizer(config)
|
| 794 |
+
|
| 795 |
+
self.quant_conv = nn.Conv2d(config.z_channels, config.embed_dim, 1)
|
| 796 |
+
self.post_quant_conv = nn.Conv2d(config.embed_dim, config.z_channels, 1)
|
| 797 |
+
|
| 798 |
+
self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1)
|
| 799 |
+
|
| 800 |
+
self.post_init()
|
| 801 |
+
|
| 802 |
+
def encode(self, x: torch.Tensor):
|
| 803 |
+
h = self.encoder(x)
|
| 804 |
+
h = self.quant_conv(h)
|
| 805 |
+
codes = self.quantize(h)
|
| 806 |
+
return codes
|
| 807 |
+
|
| 808 |
+
def decode(self, x: torch.Tensor):
|
| 809 |
+
quant = self.quantize.embedding(x.flatten())
|
| 810 |
+
b, h, w, c = quant.shape
|
| 811 |
+
quant = quant.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
|
| 812 |
+
quant2 = self.post_quant_conv(quant)
|
| 813 |
+
image = self.decoder(quant2, quant)
|
| 814 |
+
image = image.reshape(
|
| 815 |
+
b,
|
| 816 |
+
self.config.out_channels,
|
| 817 |
+
h * self.spatial_scale_factor,
|
| 818 |
+
w * self.spatial_scale_factor,
|
| 819 |
+
)
|
| 820 |
+
return image
|
| 821 |
+
|
| 822 |
+
@property
|
| 823 |
+
def device(self):
|
| 824 |
+
return next(self.parameters()).device
|
| 825 |
+
|
| 826 |
+
@property
|
| 827 |
+
def dtype(self):
|
| 828 |
+
return next(self.parameters()).dtype
|
modeling_qwen2vit.py
ADDED
|
@@ -0,0 +1,842 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""PyTorch Qwen2-VL model."""
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
from torch import Tensor
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
import torch.utils.checkpoint
|
| 32 |
+
|
| 33 |
+
from transformers.activations import ACT2FN
|
| 34 |
+
from transformers.cache_utils import Cache, StaticCache
|
| 35 |
+
from transformers.modeling_attn_mask_utils import (
|
| 36 |
+
AttentionMaskConverter,
|
| 37 |
+
)
|
| 38 |
+
from transformers.modeling_outputs import (
|
| 39 |
+
BaseModelOutputWithPast,
|
| 40 |
+
ModelOutput,
|
| 41 |
+
)
|
| 42 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 43 |
+
from transformers.utils import (
|
| 44 |
+
add_start_docstrings,
|
| 45 |
+
add_start_docstrings_to_model_forward,
|
| 46 |
+
is_torch_npu_available,
|
| 47 |
+
is_flash_attn_2_available,
|
| 48 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 49 |
+
logging,
|
| 50 |
+
replace_return_docstrings,
|
| 51 |
+
)
|
| 52 |
+
from .configuration_qwen2vit import Qwen2VLConfig, Qwen2VLVisionConfig
|
| 53 |
+
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 54 |
+
|
| 55 |
+
from einops import rearrange
|
| 56 |
+
|
| 57 |
+
logger = logging.get_logger(__name__)
|
| 58 |
+
|
| 59 |
+
_CONFIG_FOR_DOC = "Qwen2VLConfig"
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
import xformers.ops as xops
|
| 63 |
+
|
| 64 |
+
is_xformers_available = True
|
| 65 |
+
except Exception as e:
|
| 66 |
+
is_xformers_available = False
|
| 67 |
+
|
| 68 |
+
if is_flash_attn_2_available():
|
| 69 |
+
from flash_attn import flash_attn_varlen_func
|
| 70 |
+
|
| 71 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 72 |
+
else:
|
| 73 |
+
flash_attn_varlen_func = None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def init_weights(m):
|
| 77 |
+
if isinstance(m, nn.Linear):
|
| 78 |
+
# we use xavier_uniform following official JAX ViT:
|
| 79 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 80 |
+
if m.bias is not None:
|
| 81 |
+
nn.init.constant_(m.bias, 0)
|
| 82 |
+
elif isinstance(m, nn.nn.LayerNorm):
|
| 83 |
+
nn.init.constant_(m.bias, 0)
|
| 84 |
+
nn.init.constant_(m.weight, 1.0)
|
| 85 |
+
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
| 86 |
+
w = m.weight.data
|
| 87 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
|
| 91 |
+
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
|
| 92 |
+
"""
|
| 93 |
+
Base class for Qwen2VL causal language model (or autoregressive) outputs.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 97 |
+
Language modeling loss (for next-token prediction).
|
| 98 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 99 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 100 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 101 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 102 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 103 |
+
|
| 104 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 105 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 106 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 107 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 108 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 109 |
+
|
| 110 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 111 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 112 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 113 |
+
sequence_length)`.
|
| 114 |
+
|
| 115 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 116 |
+
heads.
|
| 117 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
| 118 |
+
The rope index difference between sequence length and multimodal rope.
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
loss: Optional[torch.FloatTensor] = None
|
| 122 |
+
logits: torch.FloatTensor = None
|
| 123 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 124 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 125 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 126 |
+
rope_deltas: Optional[torch.LongTensor] = None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Qwen2VLRotaryEmbedding(nn.Module):
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
dim=None,
|
| 133 |
+
max_position_embeddings=2048,
|
| 134 |
+
base=10000,
|
| 135 |
+
device=None,
|
| 136 |
+
scaling_factor=1.0,
|
| 137 |
+
rope_type="default",
|
| 138 |
+
config: Optional[Qwen2VLConfig] = None,
|
| 139 |
+
):
|
| 140 |
+
super().__init__()
|
| 141 |
+
# TODO (joao): remove the `if` below, only used for BC
|
| 142 |
+
self.rope_kwargs = {}
|
| 143 |
+
if config is None:
|
| 144 |
+
logger.warning_once(
|
| 145 |
+
"`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
| 146 |
+
"`config` argument. All other arguments will be removed in v4.46"
|
| 147 |
+
)
|
| 148 |
+
self.rope_kwargs = {
|
| 149 |
+
"rope_type": rope_type,
|
| 150 |
+
"factor": scaling_factor,
|
| 151 |
+
"dim": dim,
|
| 152 |
+
"base": base,
|
| 153 |
+
"max_position_embeddings": max_position_embeddings,
|
| 154 |
+
}
|
| 155 |
+
self.rope_type = rope_type
|
| 156 |
+
self.max_seq_len_cached = max_position_embeddings
|
| 157 |
+
self.original_max_seq_len = max_position_embeddings
|
| 158 |
+
else:
|
| 159 |
+
# BC: "rope_type" was originally "type"
|
| 160 |
+
if config.rope_scaling is not None:
|
| 161 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 162 |
+
else:
|
| 163 |
+
self.rope_type = "default"
|
| 164 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 165 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 166 |
+
|
| 167 |
+
self.config = config
|
| 168 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 169 |
+
|
| 170 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
| 171 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 172 |
+
self.original_inv_freq = self.inv_freq
|
| 173 |
+
|
| 174 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
| 175 |
+
"""
|
| 176 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
| 177 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
| 178 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
| 179 |
+
"""
|
| 180 |
+
seq_len = torch.max(position_ids) + 1
|
| 181 |
+
if seq_len > self.max_seq_len_cached: # growth
|
| 182 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 183 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
| 184 |
+
)
|
| 185 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
| 186 |
+
self.max_seq_len_cached = seq_len
|
| 187 |
+
|
| 188 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
| 189 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 190 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
| 191 |
+
|
| 192 |
+
@torch.no_grad()
|
| 193 |
+
def forward(self, x, position_ids):
|
| 194 |
+
if "dynamic" in self.rope_type:
|
| 195 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
| 196 |
+
|
| 197 |
+
# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
|
| 198 |
+
# So we expand the inv_freq to shape (3, ...)
|
| 199 |
+
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
| 200 |
+
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
| 201 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
| 202 |
+
device_type = x.device.type
|
| 203 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 204 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 205 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
| 206 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 207 |
+
cos = emb.cos()
|
| 208 |
+
sin = emb.sin()
|
| 209 |
+
|
| 210 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
| 211 |
+
cos = cos * self.attention_scaling
|
| 212 |
+
sin = sin * self.attention_scaling
|
| 213 |
+
|
| 214 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 218 |
+
def rotate_half(x):
|
| 219 |
+
"""Rotates half the hidden dims of the input."""
|
| 220 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 221 |
+
x2 = x[..., x.shape[-1] // 2:]
|
| 222 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
| 226 |
+
orig_dtype = tensor.dtype
|
| 227 |
+
tensor = tensor.float()
|
| 228 |
+
cos = freqs.cos()
|
| 229 |
+
sin = freqs.sin()
|
| 230 |
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 231 |
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 232 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
| 233 |
+
output = output.to(orig_dtype)
|
| 234 |
+
return output
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def apply_rotary_pos_emb_vision_batch(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
| 238 |
+
orig_dtype = tensor.dtype
|
| 239 |
+
tensor = tensor.float()
|
| 240 |
+
cos = freqs.cos()
|
| 241 |
+
sin = freqs.sin()
|
| 242 |
+
cos = cos.repeat(1, 1, 1, 2).float()
|
| 243 |
+
sin = sin.repeat(1, 1, 1, 2).float()
|
| 244 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
| 245 |
+
output = output.to(orig_dtype)
|
| 246 |
+
return output
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 250 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 251 |
+
super().__init__()
|
| 252 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 253 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 254 |
+
|
| 255 |
+
def forward(self, seqlen: int, scale_factor: float = 1.0) -> torch.Tensor:
|
| 256 |
+
# 使用 scale_factor 动态调整 inv_freq
|
| 257 |
+
scaled_inv_freq = self.inv_freq * scale_factor
|
| 258 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 259 |
+
freqs = torch.outer(seq, scaled_inv_freq)
|
| 260 |
+
return freqs
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class PatchEmbed(nn.Module):
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
patch_size: int = 14,
|
| 267 |
+
temporal_patch_size: int = 2,
|
| 268 |
+
in_channels: int = 3,
|
| 269 |
+
embed_dim: int = 1152,
|
| 270 |
+
) -> None:
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.patch_size = patch_size
|
| 273 |
+
self.temporal_patch_size = temporal_patch_size
|
| 274 |
+
self.in_channels = in_channels
|
| 275 |
+
self.embed_dim = embed_dim
|
| 276 |
+
|
| 277 |
+
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
| 278 |
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
| 279 |
+
|
| 280 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 281 |
+
target_dtype = self.proj.weight.dtype
|
| 282 |
+
if is_torch_npu_available():
|
| 283 |
+
# if True:
|
| 284 |
+
hidden_states = F.linear(hidden_states, self.proj.weight.view(self.embed_dim, -1))
|
| 285 |
+
else:
|
| 286 |
+
hidden_states = hidden_states.view(
|
| 287 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
| 288 |
+
)
|
| 289 |
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
| 290 |
+
return hidden_states
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class PatchMerger(nn.Module):
|
| 294 |
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
| 295 |
+
super().__init__()
|
| 296 |
+
self.hidden_size = context_dim * (spatial_merge_size ** 2)
|
| 297 |
+
self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
|
| 298 |
+
self.mlp = nn.Sequential(
|
| 299 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
| 300 |
+
nn.GELU(),
|
| 301 |
+
nn.Linear(self.hidden_size, dim),
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
def forward(self, x: torch.Tensor, grid_thw) -> torch.Tensor:
|
| 305 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
| 306 |
+
return x
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class VisionMlp(nn.Module):
|
| 310 |
+
def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.fc1 = nn.Linear(dim, hidden_dim)
|
| 313 |
+
self.act = ACT2FN[hidden_act]
|
| 314 |
+
self.fc2 = nn.Linear(hidden_dim, dim)
|
| 315 |
+
|
| 316 |
+
def forward(self, x) -> torch.Tensor:
|
| 317 |
+
return self.fc2(self.act(self.fc1(x)))
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class VisionAttention(nn.Module):
|
| 321 |
+
def __init__(self, dim: int, num_heads: int = 16, ) -> None:
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.num_heads = num_heads
|
| 324 |
+
self.head_dim = dim // num_heads
|
| 325 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 326 |
+
self.proj = nn.Linear(dim, dim)
|
| 327 |
+
|
| 328 |
+
def forward(
|
| 329 |
+
self,
|
| 330 |
+
hidden_states: torch.Tensor,
|
| 331 |
+
cu_seqlens: torch.Tensor,
|
| 332 |
+
rotary_pos_emb: torch.Tensor = None
|
| 333 |
+
) -> torch.Tensor:
|
| 334 |
+
seq_length = hidden_states.shape[0]
|
| 335 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 336 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 337 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 338 |
+
|
| 339 |
+
attention_mask = torch.full(
|
| 340 |
+
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
| 341 |
+
)
|
| 342 |
+
for i in range(1, len(cu_seqlens)):
|
| 343 |
+
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
|
| 344 |
+
|
| 345 |
+
q = q.transpose(0, 1)
|
| 346 |
+
k = k.transpose(0, 1)
|
| 347 |
+
v = v.transpose(0, 1)
|
| 348 |
+
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
| 349 |
+
attn_weights = attn_weights + attention_mask
|
| 350 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 351 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 352 |
+
attn_output = attn_output.transpose(0, 1)
|
| 353 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 354 |
+
attn_output = self.proj(attn_output)
|
| 355 |
+
return attn_output
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class BatchVisionAttention(nn.Module):
|
| 359 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 360 |
+
super().__init__()
|
| 361 |
+
self.num_heads = num_heads
|
| 362 |
+
self.head_dim = dim // num_heads
|
| 363 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 364 |
+
self.proj = nn.Linear(dim, dim)
|
| 365 |
+
|
| 366 |
+
def forward(
|
| 367 |
+
self,
|
| 368 |
+
hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
|
| 369 |
+
attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
|
| 370 |
+
rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
|
| 371 |
+
) -> torch.Tensor:
|
| 372 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 373 |
+
|
| 374 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0,
|
| 375 |
+
3, 1,
|
| 376 |
+
4).unbind(
|
| 377 |
+
0)
|
| 378 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
| 379 |
+
|
| 380 |
+
if rotary_pos_emb is not None:
|
| 381 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 382 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 383 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 384 |
+
|
| 385 |
+
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 386 |
+
if attention_mask is not None:
|
| 387 |
+
attn_weights = attn_weights + attention_mask
|
| 388 |
+
|
| 389 |
+
# Softmax
|
| 390 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 391 |
+
|
| 392 |
+
attn_output = torch.matmul(attn_weights, v) # [batch_size, num_heads, seq_len, head_dim]
|
| 393 |
+
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
|
| 394 |
+
return self.proj(attn_output)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class VisionXformerAttention(nn.Module):
|
| 398 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 399 |
+
super().__init__()
|
| 400 |
+
self.num_heads = num_heads
|
| 401 |
+
self.head_dim = dim // num_heads
|
| 402 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 403 |
+
self.proj = nn.Linear(dim, dim)
|
| 404 |
+
|
| 405 |
+
def forward(
|
| 406 |
+
self,
|
| 407 |
+
hidden_states: torch.Tensor,
|
| 408 |
+
cu_seqlens: torch.Tensor,
|
| 409 |
+
rotary_pos_emb: torch.Tensor = None
|
| 410 |
+
) -> torch.Tensor:
|
| 411 |
+
seq_length = hidden_states.shape[0]
|
| 412 |
+
|
| 413 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 414 |
+
|
| 415 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
|
| 416 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
|
| 417 |
+
|
| 418 |
+
seqlens = [cu_seqlens[0]] + [cu_seqlens[i] - cu_seqlens[i - 1] for i in range(1, len(cu_seqlens))]
|
| 419 |
+
attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 420 |
+
|
| 421 |
+
attn_output = xops.memory_efficient_attention(
|
| 422 |
+
q, k, v.unsqueeze(0),
|
| 423 |
+
attn_bias=attn_bias,
|
| 424 |
+
scale=1.0 / math.sqrt(self.head_dim)
|
| 425 |
+
)
|
| 426 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 427 |
+
attn_output = self.proj(attn_output)
|
| 428 |
+
return attn_output
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class BatchVisionXformerAttention(nn.Module):
|
| 432 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 433 |
+
super().__init__()
|
| 434 |
+
self.num_heads = num_heads
|
| 435 |
+
self.head_dim = dim // num_heads
|
| 436 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 437 |
+
self.proj = nn.Linear(dim, dim)
|
| 438 |
+
|
| 439 |
+
def forward(
|
| 440 |
+
self,
|
| 441 |
+
hidden_states: torch.Tensor,
|
| 442 |
+
attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
|
| 443 |
+
rotary_pos_emb: torch.Tensor = None
|
| 444 |
+
) -> torch.Tensor:
|
| 445 |
+
seq_length = hidden_states.shape[0]
|
| 446 |
+
batch_size, seq_len = hidden_states.shape
|
| 447 |
+
|
| 448 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0,
|
| 449 |
+
3, 1,
|
| 450 |
+
4).unbind(
|
| 451 |
+
0)
|
| 452 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
| 453 |
+
|
| 454 |
+
if rotary_pos_emb is not None:
|
| 455 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 456 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 457 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 458 |
+
|
| 459 |
+
attn_output = xops.memory_efficient_attention(
|
| 460 |
+
q, k, v,
|
| 461 |
+
attn_bias=attention_mask,
|
| 462 |
+
scale=1.0 / math.sqrt(self.head_dim)
|
| 463 |
+
)
|
| 464 |
+
attn_output = attn_output.reshape(batch_size, seq_len, -1)
|
| 465 |
+
return self.proj(attn_output)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class VisionFlashAttention2(nn.Module):
|
| 469 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 470 |
+
super().__init__()
|
| 471 |
+
self.num_heads = num_heads
|
| 472 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 473 |
+
self.proj = nn.Linear(dim, dim)
|
| 474 |
+
|
| 475 |
+
def forward(
|
| 476 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 477 |
+
) -> torch.Tensor:
|
| 478 |
+
seq_length = hidden_states.shape[0]
|
| 479 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 480 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 481 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 482 |
+
|
| 483 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 484 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 485 |
+
seq_length, -1
|
| 486 |
+
)
|
| 487 |
+
attn_output = self.proj(attn_output)
|
| 488 |
+
return attn_output
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class BatchVisionFlashAttention2(nn.Module):
|
| 492 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 493 |
+
super().__init__()
|
| 494 |
+
self.num_heads = num_heads
|
| 495 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 496 |
+
self.proj = nn.Linear(dim, dim)
|
| 497 |
+
|
| 498 |
+
def forward(
|
| 499 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 500 |
+
) -> torch.Tensor:
|
| 501 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 502 |
+
|
| 503 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1,
|
| 504 |
+
4).unbind(0)
|
| 505 |
+
|
| 506 |
+
if rotary_pos_emb is not None:
|
| 507 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 508 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 509 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 510 |
+
|
| 511 |
+
q = rearrange(q, 'b h l d -> b l h d')
|
| 512 |
+
k = rearrange(k, 'b h l d -> b l h d')
|
| 513 |
+
v = rearrange(v, 'b h l d -> b l h d')
|
| 514 |
+
|
| 515 |
+
attn_output = _flash_attention_forward(q, k, v).reshape(batch_size, seq_len, -1)
|
| 516 |
+
attn_output = self.proj(attn_output)
|
| 517 |
+
return attn_output
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class VisionSdpaAttention(nn.Module):
|
| 521 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 522 |
+
super().__init__()
|
| 523 |
+
self.num_heads = num_heads
|
| 524 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 525 |
+
self.proj = nn.Linear(dim, dim)
|
| 526 |
+
|
| 527 |
+
def forward(
|
| 528 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 529 |
+
) -> torch.Tensor:
|
| 530 |
+
seq_length = hidden_states.shape[0]
|
| 531 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 532 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 533 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 534 |
+
|
| 535 |
+
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
| 536 |
+
for i in range(1, len(cu_seqlens)):
|
| 537 |
+
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
|
| 538 |
+
q = q.transpose(0, 1)
|
| 539 |
+
k = k.transpose(0, 1)
|
| 540 |
+
v = v.transpose(0, 1)
|
| 541 |
+
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
| 542 |
+
attn_output = attn_output.transpose(0, 1)
|
| 543 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 544 |
+
attn_output = self.proj(attn_output)
|
| 545 |
+
return attn_output
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class BatchVisionSdpaAttention(nn.Module):
|
| 549 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 550 |
+
super().__init__()
|
| 551 |
+
self.num_heads = num_heads
|
| 552 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 553 |
+
self.proj = nn.Linear(dim, dim)
|
| 554 |
+
|
| 555 |
+
def forward(
|
| 556 |
+
self,
|
| 557 |
+
hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
|
| 558 |
+
attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
|
| 559 |
+
rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
|
| 560 |
+
) -> torch.Tensor:
|
| 561 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 562 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1,
|
| 563 |
+
4).unbind(0)
|
| 564 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
| 565 |
+
|
| 566 |
+
if rotary_pos_emb is not None:
|
| 567 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 568 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 569 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 570 |
+
|
| 571 |
+
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
| 572 |
+
attn_output = attn_output.transpose(1, 2)
|
| 573 |
+
attn_output = attn_output.reshape(batch_size, seq_len, -1)
|
| 574 |
+
attn_output = self.proj(attn_output)
|
| 575 |
+
return attn_output
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
QWEN2_VL_VISION_ATTENTION_CLASSES = {
|
| 579 |
+
"eager": VisionAttention,
|
| 580 |
+
"flash_attention_2": VisionFlashAttention2,
|
| 581 |
+
"sdpa": VisionSdpaAttention,
|
| 582 |
+
"xformers": VisionXformerAttention,
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES = {
|
| 586 |
+
"eager": BatchVisionAttention,
|
| 587 |
+
"flash_attention_2": VisionFlashAttention2,
|
| 588 |
+
"sdpa": BatchVisionSdpaAttention,
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class Qwen2VLVisionBlock(nn.Module):
|
| 593 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
| 594 |
+
super().__init__()
|
| 595 |
+
|
| 596 |
+
self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 597 |
+
self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 598 |
+
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
|
| 599 |
+
|
| 600 |
+
self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
|
| 601 |
+
config.embed_dim, num_heads=config.num_heads,
|
| 602 |
+
)
|
| 603 |
+
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
|
| 604 |
+
|
| 605 |
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb, grid_thw) -> torch.Tensor:
|
| 606 |
+
hidden_states = hidden_states + self.attn(
|
| 607 |
+
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
| 608 |
+
)
|
| 609 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 610 |
+
return hidden_states
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
class Qwen2VLBatchVisionBlock(nn.Module):
|
| 614 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
| 615 |
+
super().__init__()
|
| 616 |
+
self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 617 |
+
self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 618 |
+
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
|
| 619 |
+
|
| 620 |
+
self.attn = QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES[attn_implementation](
|
| 621 |
+
config.embed_dim, num_heads=config.num_heads,
|
| 622 |
+
)
|
| 623 |
+
self.mlp = VisionMlp(config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
|
| 624 |
+
|
| 625 |
+
def forward(
|
| 626 |
+
self,
|
| 627 |
+
hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
|
| 628 |
+
attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
|
| 629 |
+
rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
|
| 630 |
+
) -> torch.Tensor:
|
| 631 |
+
# Attention
|
| 632 |
+
hidden_states = hidden_states + self.attn(
|
| 633 |
+
self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
|
| 634 |
+
)
|
| 635 |
+
# MLP
|
| 636 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 637 |
+
return hidden_states
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
| 641 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 642 |
+
attention_mask: torch.Tensor,
|
| 643 |
+
sequence_length: int,
|
| 644 |
+
target_length: int,
|
| 645 |
+
dtype: torch.dtype,
|
| 646 |
+
device: torch.device,
|
| 647 |
+
min_dtype: float,
|
| 648 |
+
cache_position: torch.Tensor,
|
| 649 |
+
batch_size: int,
|
| 650 |
+
):
|
| 651 |
+
"""
|
| 652 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 653 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 654 |
+
|
| 655 |
+
Args:
|
| 656 |
+
attention_mask (`torch.Tensor`):
|
| 657 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
| 658 |
+
sequence_length (`int`):
|
| 659 |
+
The sequence length being processed.
|
| 660 |
+
target_length (`int`):
|
| 661 |
+
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
| 662 |
+
dtype (`torch.dtype`):
|
| 663 |
+
The dtype to use for the 4D attention mask.
|
| 664 |
+
device (`torch.device`):
|
| 665 |
+
The device to plcae the 4D attention mask on.
|
| 666 |
+
min_dtype (`float`):
|
| 667 |
+
The minimum value representable with the dtype `dtype`.
|
| 668 |
+
cache_position (`torch.Tensor`):
|
| 669 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 670 |
+
batch_size (`torch.Tensor`):
|
| 671 |
+
Batch size.
|
| 672 |
+
"""
|
| 673 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 674 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 675 |
+
causal_mask = attention_mask
|
| 676 |
+
else:
|
| 677 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
| 678 |
+
if sequence_length != 1:
|
| 679 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 680 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 681 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 682 |
+
if attention_mask is not None:
|
| 683 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 684 |
+
mask_length = attention_mask.shape[-1]
|
| 685 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
| 686 |
+
padding_mask = padding_mask == 0
|
| 687 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 688 |
+
padding_mask, min_dtype
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
return causal_mask
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
|
| 695 |
+
class Qwen2RMSNorm(nn.Module):
|
| 696 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 697 |
+
"""
|
| 698 |
+
Qwen2RMSNorm is equivalent to T5nn.LayerNorm
|
| 699 |
+
"""
|
| 700 |
+
super().__init__()
|
| 701 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 702 |
+
self.variance_epsilon = eps
|
| 703 |
+
|
| 704 |
+
def forward(self, hidden_states):
|
| 705 |
+
input_dtype = hidden_states.dtype
|
| 706 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 707 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 708 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 709 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 710 |
+
|
| 711 |
+
def extra_repr(self):
|
| 712 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
|
| 716 |
+
class Qwen2MLP(nn.Module):
|
| 717 |
+
def __init__(self, config):
|
| 718 |
+
super().__init__()
|
| 719 |
+
self.hidden_size = config.hidden_size
|
| 720 |
+
self.intermediate_size = config.intermediate_size
|
| 721 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 722 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 723 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 724 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 725 |
+
|
| 726 |
+
def forward(self, hidden_state):
|
| 727 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 731 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 732 |
+
"""
|
| 733 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 734 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 735 |
+
"""
|
| 736 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 737 |
+
if n_rep == 1:
|
| 738 |
+
return hidden_states
|
| 739 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 740 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class Qwen2VLPreTrainedModel(PreTrainedModel):
|
| 744 |
+
config_class = Qwen2VLConfig
|
| 745 |
+
base_model_prefix = "model"
|
| 746 |
+
supports_gradient_checkpointing = True
|
| 747 |
+
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
|
| 748 |
+
_skip_keys_device_placement = "past_key_values"
|
| 749 |
+
_supports_flash_attn_2 = True
|
| 750 |
+
_supports_sdpa = True
|
| 751 |
+
_supports_cache_class = True
|
| 752 |
+
_supports_static_cache = True
|
| 753 |
+
|
| 754 |
+
def _init_weights(self, module):
|
| 755 |
+
std = self.config.initializer_range
|
| 756 |
+
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
| 757 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 758 |
+
if module.bias is not None:
|
| 759 |
+
module.bias.data.zero_()
|
| 760 |
+
elif isinstance(module, nn.Embedding):
|
| 761 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 762 |
+
if module.padding_idx is not None:
|
| 763 |
+
module.weight.data[module.padding_idx].zero_()
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
| 767 |
+
config_class = Qwen2VLVisionConfig
|
| 768 |
+
_no_split_modules = ["Qwen2VLVisionBlock"]
|
| 769 |
+
|
| 770 |
+
def __init__(self, config) -> None:
|
| 771 |
+
super().__init__(config)
|
| 772 |
+
self.spatial_merge_size = config.spatial_merge_size
|
| 773 |
+
|
| 774 |
+
self.patch_embed = PatchEmbed(
|
| 775 |
+
patch_size=config.patch_size,
|
| 776 |
+
temporal_patch_size=config.temporal_patch_size,
|
| 777 |
+
in_channels=config.in_channels,
|
| 778 |
+
embed_dim=config.embed_dim,
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
head_dim = config.embed_dim // config.num_heads
|
| 782 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 783 |
+
|
| 784 |
+
self.blocks = nn.ModuleList(
|
| 785 |
+
[Qwen2VLVisionBlock(config, config.attn_implementation) for _ in range(config.depth)]
|
| 786 |
+
)
|
| 787 |
+
self.merger = PatchMerger(
|
| 788 |
+
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
def get_dtype(self) -> torch.dtype:
|
| 792 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
| 793 |
+
|
| 794 |
+
def get_device(self) -> torch.device:
|
| 795 |
+
return self.blocks[0].mlp.fc2.weight.device
|
| 796 |
+
|
| 797 |
+
def rot_pos_emb(self, grid_thw):
|
| 798 |
+
pos_ids = []
|
| 799 |
+
for t, h, w in grid_thw:
|
| 800 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 801 |
+
hpos_ids = hpos_ids.reshape(
|
| 802 |
+
h // self.spatial_merge_size,
|
| 803 |
+
self.spatial_merge_size,
|
| 804 |
+
w // self.spatial_merge_size,
|
| 805 |
+
self.spatial_merge_size,
|
| 806 |
+
)
|
| 807 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
| 808 |
+
hpos_ids = hpos_ids.flatten()
|
| 809 |
+
|
| 810 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 811 |
+
wpos_ids = wpos_ids.reshape(
|
| 812 |
+
h // self.spatial_merge_size,
|
| 813 |
+
self.spatial_merge_size,
|
| 814 |
+
w // self.spatial_merge_size,
|
| 815 |
+
self.spatial_merge_size,
|
| 816 |
+
)
|
| 817 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
| 818 |
+
wpos_ids = wpos_ids.flatten()
|
| 819 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 820 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 821 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 822 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 823 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
| 824 |
+
return rotary_pos_emb
|
| 825 |
+
|
| 826 |
+
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor,
|
| 827 |
+
output_hidden_states=False, org_forward=False, ) -> torch.Tensor:
|
| 828 |
+
hidden_states = self.patch_embed(hidden_states)
|
| 829 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
| 830 |
+
|
| 831 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
| 832 |
+
dim=0, dtype=torch.int32
|
| 833 |
+
)
|
| 834 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 835 |
+
|
| 836 |
+
for blk in self.blocks:
|
| 837 |
+
hidden_states = blk(hidden_states,
|
| 838 |
+
cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb,
|
| 839 |
+
grid_thw=grid_thw)
|
| 840 |
+
|
| 841 |
+
hidden_states = self.merger(hidden_states, grid_thw)
|
| 842 |
+
return hidden_states
|
modeling_rope_utils.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import Optional, Tuple
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 20 |
+
from transformers.utils import is_torch_available, logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_torch_available():
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _compute_default_rope_parameters(
|
| 31 |
+
config: Optional[PretrainedConfig] = None,
|
| 32 |
+
device: Optional["torch.device"] = None,
|
| 33 |
+
seq_len: Optional[int] = None,
|
| 34 |
+
**rope_kwargs,
|
| 35 |
+
) -> Tuple["torch.Tensor", float]:
|
| 36 |
+
"""
|
| 37 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 38 |
+
Args:
|
| 39 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 40 |
+
The model configuration.
|
| 41 |
+
device (`torch.device`):
|
| 42 |
+
The device to use for initialization of the inverse frequencies.
|
| 43 |
+
seq_len (`int`, *optional*):
|
| 44 |
+
The current sequence length. Unused for this type of RoPE.
|
| 45 |
+
rope_kwargs (`Dict`, *optional*):
|
| 46 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 47 |
+
Returns:
|
| 48 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 49 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 50 |
+
"""
|
| 51 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 52 |
+
raise ValueError(
|
| 53 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 54 |
+
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 55 |
+
)
|
| 56 |
+
if len(rope_kwargs) > 0:
|
| 57 |
+
base = rope_kwargs["base"]
|
| 58 |
+
dim = rope_kwargs["dim"]
|
| 59 |
+
elif config is not None:
|
| 60 |
+
base = config.rope_theta
|
| 61 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 62 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 63 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 64 |
+
|
| 65 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 66 |
+
|
| 67 |
+
# Compute the inverse frequencies
|
| 68 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
| 69 |
+
return inv_freq, attention_factor
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _compute_linear_scaling_rope_parameters(
|
| 73 |
+
config: Optional[PretrainedConfig] = None,
|
| 74 |
+
device: Optional["torch.device"] = None,
|
| 75 |
+
seq_len: Optional[int] = None,
|
| 76 |
+
**rope_kwargs,
|
| 77 |
+
) -> Tuple["torch.Tensor", float]:
|
| 78 |
+
"""
|
| 79 |
+
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
|
| 80 |
+
Args:
|
| 81 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 82 |
+
The model configuration.
|
| 83 |
+
device (`torch.device`):
|
| 84 |
+
The device to use for initialization of the inverse frequencies.
|
| 85 |
+
seq_len (`int`, *optional*):
|
| 86 |
+
The current sequence length. Unused for this type of RoPE.
|
| 87 |
+
rope_kwargs (`Dict`, *optional*):
|
| 88 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 89 |
+
Returns:
|
| 90 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 91 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 92 |
+
"""
|
| 93 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 94 |
+
raise ValueError(
|
| 95 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 96 |
+
f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 97 |
+
)
|
| 98 |
+
if len(rope_kwargs) > 0:
|
| 99 |
+
factor = rope_kwargs["factor"]
|
| 100 |
+
elif config is not None:
|
| 101 |
+
factor = config.rope_scaling["factor"]
|
| 102 |
+
|
| 103 |
+
# Gets the default RoPE parameters
|
| 104 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
| 105 |
+
|
| 106 |
+
# Then applies linear scaling to the frequencies.
|
| 107 |
+
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
|
| 108 |
+
# applying scaling to the inverse frequencies is equivalent.
|
| 109 |
+
inv_freq /= factor
|
| 110 |
+
return inv_freq, attention_factor
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _compute_dynamic_ntk_parameters(
|
| 114 |
+
config: Optional[PretrainedConfig] = None,
|
| 115 |
+
device: Optional["torch.device"] = None,
|
| 116 |
+
seq_len: Optional[int] = None,
|
| 117 |
+
**rope_kwargs,
|
| 118 |
+
) -> Tuple["torch.Tensor", float]:
|
| 119 |
+
"""
|
| 120 |
+
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
|
| 121 |
+
Args:
|
| 122 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 123 |
+
The model configuration.
|
| 124 |
+
device (`torch.device`):
|
| 125 |
+
The device to use for initialization of the inverse frequencies.
|
| 126 |
+
seq_len (`int`, *optional*):
|
| 127 |
+
The current sequence length, used to update the dynamic RoPE at inference time.
|
| 128 |
+
rope_kwargs (`Dict`, *optional*):
|
| 129 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 130 |
+
Returns:
|
| 131 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 132 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 133 |
+
"""
|
| 134 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
| 135 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 138 |
+
f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 139 |
+
)
|
| 140 |
+
if len(rope_kwargs) > 0:
|
| 141 |
+
base = rope_kwargs["base"]
|
| 142 |
+
dim = rope_kwargs["dim"]
|
| 143 |
+
max_position_embeddings = rope_kwargs["max_position_embeddings"]
|
| 144 |
+
factor = rope_kwargs["factor"]
|
| 145 |
+
elif config is not None:
|
| 146 |
+
base = config.rope_theta
|
| 147 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 148 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 149 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 150 |
+
max_position_embeddings = config.max_position_embeddings
|
| 151 |
+
factor = config.rope_scaling["factor"]
|
| 152 |
+
|
| 153 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 154 |
+
|
| 155 |
+
# seq_len: default to max_position_embeddings, e.g. at init time
|
| 156 |
+
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
| 157 |
+
|
| 158 |
+
# Compute the inverse frequencies
|
| 159 |
+
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
| 160 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
| 161 |
+
return inv_freq, attention_factor
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _compute_yarn_parameters(
|
| 165 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
| 166 |
+
) -> Tuple["torch.Tensor", float]:
|
| 167 |
+
"""
|
| 168 |
+
Computes the inverse frequencies with NTK scaling. Please refer to the
|
| 169 |
+
[original paper](https://arxiv.org/abs/2309.00071)
|
| 170 |
+
Args:
|
| 171 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 172 |
+
The model configuration.
|
| 173 |
+
device (`torch.device`):
|
| 174 |
+
The device to use for initialization of the inverse frequencies.
|
| 175 |
+
seq_len (`int`, *optional*):
|
| 176 |
+
The current sequence length. Unused for this type of RoPE.
|
| 177 |
+
rope_kwargs (`Dict`, *optional*):
|
| 178 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 179 |
+
Returns:
|
| 180 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 181 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 182 |
+
"""
|
| 183 |
+
# No need to keep BC with yarn, unreleased when this new pattern was created.
|
| 184 |
+
if len(rope_kwargs) > 0:
|
| 185 |
+
raise ValueError(
|
| 186 |
+
f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
base = config.rope_theta
|
| 190 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 191 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 192 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 193 |
+
max_position_embeddings = config.max_position_embeddings
|
| 194 |
+
factor = config.rope_scaling["factor"]
|
| 195 |
+
|
| 196 |
+
# Sets the attention factor as suggested in the paper
|
| 197 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
| 198 |
+
if attention_factor is None:
|
| 199 |
+
attention_factor = 0.1 * math.log(factor) + 1.0
|
| 200 |
+
|
| 201 |
+
# Optional config options
|
| 202 |
+
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
| 203 |
+
beta_fast = config.rope_scaling.get("beta_fast") or 32
|
| 204 |
+
beta_slow = config.rope_scaling.get("beta_slow") or 1
|
| 205 |
+
|
| 206 |
+
# Compute the inverse frequencies
|
| 207 |
+
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
|
| 208 |
+
"""Inverse dimension formula to find the dimension based on the number of rotations"""
|
| 209 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
| 210 |
+
|
| 211 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
|
| 212 |
+
"""Find dimension range bounds based on rotations"""
|
| 213 |
+
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
| 214 |
+
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
| 215 |
+
return max(low, 0), min(high, dim - 1)
|
| 216 |
+
|
| 217 |
+
def linear_ramp_factor(min, max, dim):
|
| 218 |
+
if min == max:
|
| 219 |
+
max += 0.001 # Prevent singularity
|
| 220 |
+
|
| 221 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
| 222 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 223 |
+
return ramp_func
|
| 224 |
+
|
| 225 |
+
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
| 226 |
+
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
| 227 |
+
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
| 228 |
+
inv_freq_extrapolation = 1.0 / pos_freqs
|
| 229 |
+
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
| 230 |
+
|
| 231 |
+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
| 232 |
+
|
| 233 |
+
# Get n-dimensional rotational scaling corrected for extrapolation
|
| 234 |
+
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
| 235 |
+
inv_freq = (
|
| 236 |
+
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
| 237 |
+
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
return inv_freq, attention_factor
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _compute_longrope_parameters(
|
| 244 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
| 245 |
+
) -> Tuple["torch.Tensor", float]:
|
| 246 |
+
"""
|
| 247 |
+
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
|
| 248 |
+
[original implementation](https://github.com/microsoft/LongRoPE)
|
| 249 |
+
Args:
|
| 250 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 251 |
+
The model configuration.
|
| 252 |
+
device (`torch.device`):
|
| 253 |
+
The device to use for initialization of the inverse frequencies.
|
| 254 |
+
seq_len (`int`, *optional*):
|
| 255 |
+
The current sequence length. Unused for this type of RoPE.
|
| 256 |
+
rope_kwargs (`Dict`, *optional*):
|
| 257 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 258 |
+
Returns:
|
| 259 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 260 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 261 |
+
"""
|
| 262 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
| 263 |
+
# No need to keep BC with longrope, unreleased when this new pattern was created.
|
| 264 |
+
if len(rope_kwargs) > 0:
|
| 265 |
+
raise ValueError(
|
| 266 |
+
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
|
| 267 |
+
f"{rope_kwargs}"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
base = config.rope_theta
|
| 271 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 272 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 273 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 274 |
+
long_factor = config.rope_scaling["long_factor"]
|
| 275 |
+
short_factor = config.rope_scaling["short_factor"]
|
| 276 |
+
factor = config.rope_scaling.get("factor")
|
| 277 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
| 278 |
+
|
| 279 |
+
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
|
| 280 |
+
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
| 281 |
+
# values to compute the default attention scaling factor, instead of using `factor`.
|
| 282 |
+
if hasattr(config, "original_max_position_embeddings"):
|
| 283 |
+
max_position_embeddings = config.original_max_position_embeddings
|
| 284 |
+
expanded_max_position_embeddings = config.max_position_embeddings
|
| 285 |
+
factor = expanded_max_position_embeddings / max_position_embeddings
|
| 286 |
+
else:
|
| 287 |
+
max_position_embeddings = config.max_position_embeddings
|
| 288 |
+
expanded_max_position_embeddings = max_position_embeddings * factor
|
| 289 |
+
|
| 290 |
+
# Sets the attention factor as suggested in the paper
|
| 291 |
+
if attention_factor is None:
|
| 292 |
+
if factor <= 1.0:
|
| 293 |
+
attention_factor = 1.0
|
| 294 |
+
else:
|
| 295 |
+
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
|
| 296 |
+
|
| 297 |
+
# Compute the inverse frequencies -- scaled based on the target sequence length
|
| 298 |
+
if expanded_max_position_embeddings > max_position_embeddings:
|
| 299 |
+
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
| 300 |
+
else:
|
| 301 |
+
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
| 302 |
+
inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
|
| 303 |
+
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
|
| 304 |
+
|
| 305 |
+
return inv_freq, attention_factor
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _compute_llama3_parameters(
|
| 309 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
| 310 |
+
) -> Tuple["torch.Tensor", float]:
|
| 311 |
+
"""
|
| 312 |
+
Computes the inverse frequencies for llama 3.1.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 316 |
+
The model configuration.
|
| 317 |
+
device (`torch.device`):
|
| 318 |
+
The device to use for initialization of the inverse frequencies.
|
| 319 |
+
seq_len (`int`, *optional*):
|
| 320 |
+
The current sequence length. Unused for this type of RoPE.
|
| 321 |
+
rope_kwargs (`Dict`, *optional*):
|
| 322 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 323 |
+
Returns:
|
| 324 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 325 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 326 |
+
"""
|
| 327 |
+
# Gets the default RoPE parameters
|
| 328 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
| 329 |
+
|
| 330 |
+
factor = config.rope_scaling["factor"] # `8` in the original implementation
|
| 331 |
+
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
|
| 332 |
+
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
|
| 333 |
+
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
|
| 334 |
+
|
| 335 |
+
low_freq_wavelen = old_context_len / low_freq_factor
|
| 336 |
+
high_freq_wavelen = old_context_len / high_freq_factor
|
| 337 |
+
|
| 338 |
+
wavelen = 2 * math.pi / inv_freq
|
| 339 |
+
# wavelen < high_freq_wavelen: do nothing
|
| 340 |
+
# wavelen > low_freq_wavelen: divide by factor
|
| 341 |
+
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
| 342 |
+
# otherwise: interpolate between the two, using a smooth factor
|
| 343 |
+
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
| 344 |
+
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
| 345 |
+
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
| 346 |
+
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
| 347 |
+
|
| 348 |
+
return inv_freq_llama, attention_factor
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
| 352 |
+
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
|
| 353 |
+
# parameterizations, as long as the callable has the same signature.
|
| 354 |
+
ROPE_INIT_FUNCTIONS = {
|
| 355 |
+
"default": _compute_default_rope_parameters,
|
| 356 |
+
"linear": _compute_linear_scaling_rope_parameters,
|
| 357 |
+
"dynamic": _compute_dynamic_ntk_parameters,
|
| 358 |
+
"yarn": _compute_yarn_parameters,
|
| 359 |
+
"longrope": _compute_longrope_parameters,
|
| 360 |
+
"llama3": _compute_llama3_parameters,
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
| 365 |
+
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
| 366 |
+
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
| 367 |
+
if "type" in received_keys:
|
| 368 |
+
received_keys -= {"type"}
|
| 369 |
+
required_keys.add("rope_type")
|
| 370 |
+
|
| 371 |
+
missing_keys = required_keys - received_keys
|
| 372 |
+
if missing_keys:
|
| 373 |
+
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
|
| 374 |
+
|
| 375 |
+
if optional_keys is not None:
|
| 376 |
+
unused_keys = received_keys - required_keys - optional_keys
|
| 377 |
+
else:
|
| 378 |
+
unused_keys = received_keys - required_keys
|
| 379 |
+
if unused_keys:
|
| 380 |
+
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def _validate_default_rope_parameters(config: PretrainedConfig):
|
| 384 |
+
rope_scaling = config.rope_scaling
|
| 385 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 386 |
+
required_keys = {"rope_type"}
|
| 387 |
+
received_keys = set(rope_scaling.keys())
|
| 388 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
| 392 |
+
rope_scaling = config.rope_scaling
|
| 393 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 394 |
+
required_keys = {"rope_type", "factor"}
|
| 395 |
+
received_keys = set(rope_scaling.keys())
|
| 396 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
| 397 |
+
|
| 398 |
+
factor = rope_scaling["factor"]
|
| 399 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 400 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
| 404 |
+
rope_scaling = config.rope_scaling
|
| 405 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 406 |
+
required_keys = {"rope_type", "factor"}
|
| 407 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
| 408 |
+
optional_keys = {"original_max_position_embeddings"}
|
| 409 |
+
received_keys = set(rope_scaling.keys())
|
| 410 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
| 411 |
+
|
| 412 |
+
factor = rope_scaling["factor"]
|
| 413 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 414 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _validate_yarn_parameters(config: PretrainedConfig):
|
| 418 |
+
rope_scaling = config.rope_scaling
|
| 419 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 420 |
+
required_keys = {"rope_type", "factor"}
|
| 421 |
+
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
| 422 |
+
received_keys = set(rope_scaling.keys())
|
| 423 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
| 424 |
+
|
| 425 |
+
factor = rope_scaling["factor"]
|
| 426 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 427 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 428 |
+
|
| 429 |
+
attention_factor = rope_scaling.get("attention_factor")
|
| 430 |
+
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
|
| 431 |
+
logger.warning(
|
| 432 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
| 433 |
+
)
|
| 434 |
+
beta_fast = rope_scaling.get("beta_fast")
|
| 435 |
+
if beta_fast is not None and not isinstance(beta_fast, float):
|
| 436 |
+
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
|
| 437 |
+
beta_slow = rope_scaling.get("beta_slow")
|
| 438 |
+
if beta_slow is not None and not isinstance(beta_slow, float):
|
| 439 |
+
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
|
| 440 |
+
|
| 441 |
+
if (beta_fast or 32) < (beta_slow or 1):
|
| 442 |
+
logger.warning(
|
| 443 |
+
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
|
| 444 |
+
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _validate_longrope_parameters(config: PretrainedConfig):
|
| 449 |
+
rope_scaling = config.rope_scaling
|
| 450 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 451 |
+
required_keys = {"rope_type", "short_factor", "long_factor"}
|
| 452 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
| 453 |
+
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
| 454 |
+
received_keys = set(rope_scaling.keys())
|
| 455 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
| 456 |
+
|
| 457 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 458 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 459 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 460 |
+
|
| 461 |
+
short_factor = rope_scaling.get("short_factor")
|
| 462 |
+
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
| 463 |
+
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
|
| 464 |
+
if not len(short_factor) == dim // 2:
|
| 465 |
+
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
|
| 466 |
+
|
| 467 |
+
long_factor = rope_scaling.get("long_factor")
|
| 468 |
+
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
|
| 469 |
+
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
|
| 470 |
+
if not len(long_factor) == dim // 2:
|
| 471 |
+
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
|
| 472 |
+
|
| 473 |
+
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
|
| 474 |
+
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
|
| 475 |
+
# unique to longrope (= undesirable)
|
| 476 |
+
if hasattr(config, "original_max_position_embeddings"):
|
| 477 |
+
logger.warning_once(
|
| 478 |
+
"This model has set a `original_max_position_embeddings` field, to be used together with "
|
| 479 |
+
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
|
| 480 |
+
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
|
| 481 |
+
"as it is compatible with most model architectures."
|
| 482 |
+
)
|
| 483 |
+
else:
|
| 484 |
+
factor = rope_scaling.get("factor")
|
| 485 |
+
if factor is None:
|
| 486 |
+
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
|
| 487 |
+
elif not isinstance(factor, float) or factor < 1.0:
|
| 488 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 489 |
+
|
| 490 |
+
attention_factor = rope_scaling.get("attention_factor")
|
| 491 |
+
if attention_factor is not None:
|
| 492 |
+
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
| 493 |
+
logger.warning(
|
| 494 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def _validate_llama3_parameters(config: PretrainedConfig):
|
| 499 |
+
rope_scaling = config.rope_scaling
|
| 500 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 501 |
+
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
| 502 |
+
received_keys = set(rope_scaling.keys())
|
| 503 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
| 504 |
+
|
| 505 |
+
factor = rope_scaling["factor"]
|
| 506 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 507 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 508 |
+
|
| 509 |
+
low_freq_factor = rope_scaling["low_freq_factor"]
|
| 510 |
+
high_freq_factor = rope_scaling["high_freq_factor"]
|
| 511 |
+
if low_freq_factor is None or not isinstance(low_freq_factor, float):
|
| 512 |
+
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
| 513 |
+
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
| 514 |
+
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
| 515 |
+
if high_freq_factor <= low_freq_factor:
|
| 516 |
+
logger.warning(
|
| 517 |
+
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
| 518 |
+
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
|
| 522 |
+
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
|
| 523 |
+
logger.warning(
|
| 524 |
+
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
|
| 525 |
+
f"{original_max_position_embeddings}"
|
| 526 |
+
)
|
| 527 |
+
if original_max_position_embeddings >= config.max_position_embeddings:
|
| 528 |
+
logger.warning(
|
| 529 |
+
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
|
| 530 |
+
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
|
| 535 |
+
ROPE_VALIDATION_FUNCTIONS = {
|
| 536 |
+
"default": _validate_default_rope_parameters,
|
| 537 |
+
"linear": _validate_linear_scaling_rope_parameters,
|
| 538 |
+
"dynamic": _validate_dynamic_scaling_rope_parameters,
|
| 539 |
+
"yarn": _validate_yarn_parameters,
|
| 540 |
+
"longrope": _validate_longrope_parameters,
|
| 541 |
+
"llama3": _validate_llama3_parameters,
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def rope_config_validation(config: PretrainedConfig):
|
| 546 |
+
"""
|
| 547 |
+
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
| 548 |
+
"""
|
| 549 |
+
rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
|
| 550 |
+
if rope_scaling is None:
|
| 551 |
+
return
|
| 552 |
+
|
| 553 |
+
# BC: "rope_type" was originally "type"
|
| 554 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
| 555 |
+
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
| 556 |
+
if validation_fn is not None:
|
| 557 |
+
validation_fn(config)
|
| 558 |
+
else:
|
| 559 |
+
logger.warning(
|
| 560 |
+
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
| 561 |
+
)
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoImageProcessor": "image_processing_dualvitok.DualViTokImageProcessor"
|
| 4 |
+
},
|
| 5 |
+
"do_convert_rgb": true,
|
| 6 |
+
"do_normalize": true,
|
| 7 |
+
"do_rescale": true,
|
| 8 |
+
"do_resize": true,
|
| 9 |
+
"image_mean": [
|
| 10 |
+
0.5,
|
| 11 |
+
0.5,
|
| 12 |
+
0.5
|
| 13 |
+
],
|
| 14 |
+
"image_processor_type": "DualViTokImageProcessor",
|
| 15 |
+
"image_std": [
|
| 16 |
+
0.5,
|
| 17 |
+
0.5,
|
| 18 |
+
0.5
|
| 19 |
+
],
|
| 20 |
+
"max_pixels": 1048576,
|
| 21 |
+
"min_pixels": 1024,
|
| 22 |
+
"resample": 3,
|
| 23 |
+
"rescale_factor": 0.00392156862745098,
|
| 24 |
+
"size": {
|
| 25 |
+
"max_pixels": 1048576,
|
| 26 |
+
"min_pixels": 1024
|
| 27 |
+
},
|
| 28 |
+
"spatial_factor": 16
|
| 29 |
+
}
|
processing_qwen2vit.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""
|
| 21 |
+
Processor class for Qwen2-VL.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from typing import List, Union
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from typing import Unpack
|
| 29 |
+
except ImportError:
|
| 30 |
+
from typing_extensions import Unpack
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 34 |
+
from .image_utils import ImageInput, VideoInput
|
| 35 |
+
from transformers.processing_utils import (
|
| 36 |
+
ProcessingKwargs,
|
| 37 |
+
ProcessorMixin,
|
| 38 |
+
)
|
| 39 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 40 |
+
from transformers.utils import logging
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False):
|
| 47 |
+
_defaults = {
|
| 48 |
+
"text_kwargs": {
|
| 49 |
+
"padding": False,
|
| 50 |
+
},
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Qwen2VLProcessor(ProcessorMixin):
|
| 55 |
+
r"""
|
| 56 |
+
Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.
|
| 57 |
+
[`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
|
| 58 |
+
[`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.
|
| 59 |
+
Args:
|
| 60 |
+
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
| 61 |
+
The image processor is a required input.
|
| 62 |
+
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
| 63 |
+
The tokenizer is a required input.
|
| 64 |
+
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
| 65 |
+
in a chat into a tokenizable string.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
attributes = ["image_processor", "tokenizer"]
|
| 69 |
+
valid_kwargs = ["chat_template"]
|
| 70 |
+
image_processor_class = "Qwen2VLImageProcessor"
|
| 71 |
+
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
| 72 |
+
|
| 73 |
+
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
| 74 |
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
| 75 |
+
|
| 76 |
+
def __call__(
|
| 77 |
+
self,
|
| 78 |
+
images: ImageInput = None,
|
| 79 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 80 |
+
videos: VideoInput = None,
|
| 81 |
+
**kwargs: Unpack[Qwen2VLProcessorKwargs],
|
| 82 |
+
) -> BatchFeature:
|
| 83 |
+
"""
|
| 84 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
| 85 |
+
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
| 86 |
+
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
|
| 87 |
+
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 91 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 92 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 93 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
| 94 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 95 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 96 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 97 |
+
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 98 |
+
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
| 99 |
+
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
|
| 100 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 101 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 102 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 103 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 104 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 105 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 109 |
+
|
| 110 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 111 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 112 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 113 |
+
`None`).
|
| 114 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 115 |
+
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
| 116 |
+
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
| 117 |
+
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
|
| 118 |
+
"""
|
| 119 |
+
output_kwargs = self._merge_kwargs(
|
| 120 |
+
Qwen2VLProcessorKwargs,
|
| 121 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 122 |
+
**kwargs,
|
| 123 |
+
)
|
| 124 |
+
if images is not None:
|
| 125 |
+
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
|
| 126 |
+
image_grid_thw = image_inputs["image_grid_thw"]
|
| 127 |
+
else:
|
| 128 |
+
image_inputs = {}
|
| 129 |
+
image_grid_thw = None
|
| 130 |
+
|
| 131 |
+
if videos is not None:
|
| 132 |
+
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
|
| 133 |
+
video_grid_thw = videos_inputs["video_grid_thw"]
|
| 134 |
+
else:
|
| 135 |
+
videos_inputs = {}
|
| 136 |
+
video_grid_thw = None
|
| 137 |
+
|
| 138 |
+
if not isinstance(text, list):
|
| 139 |
+
text = [text]
|
| 140 |
+
|
| 141 |
+
if image_grid_thw is not None:
|
| 142 |
+
merge_length = self.image_processor.merge_size**2
|
| 143 |
+
index = 0
|
| 144 |
+
for i in range(len(text)):
|
| 145 |
+
while "<|image_pad|>" in text[i]:
|
| 146 |
+
text[i] = text[i].replace(
|
| 147 |
+
"<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
|
| 148 |
+
)
|
| 149 |
+
index += 1
|
| 150 |
+
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
| 151 |
+
|
| 152 |
+
if video_grid_thw is not None:
|
| 153 |
+
merge_length = self.image_processor.merge_size**2
|
| 154 |
+
index = 0
|
| 155 |
+
for i in range(len(text)):
|
| 156 |
+
while "<|video_pad|>" in text[i]:
|
| 157 |
+
text[i] = text[i].replace(
|
| 158 |
+
"<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
|
| 159 |
+
)
|
| 160 |
+
index += 1
|
| 161 |
+
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
| 162 |
+
|
| 163 |
+
_ = output_kwargs["text_kwargs"].pop("padding_side", None)
|
| 164 |
+
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 165 |
+
|
| 166 |
+
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
| 167 |
+
|
| 168 |
+
def batch_decode(self, *args, **kwargs):
|
| 169 |
+
"""
|
| 170 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 171 |
+
refer to the docstring of this method for more information.
|
| 172 |
+
"""
|
| 173 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 174 |
+
|
| 175 |
+
def decode(self, *args, **kwargs):
|
| 176 |
+
"""
|
| 177 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 178 |
+
the docstring of this method for more information.
|
| 179 |
+
"""
|
| 180 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 181 |
+
|
| 182 |
+
@property
|
| 183 |
+
def model_input_names(self):
|
| 184 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 185 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 186 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a4da82e1c8624d84fc3120cc25001da7ad348fbcac426cc5c6b50375657c2a86
|
| 3 |
+
size 5401021086
|
sdxl_decoder_pipe.py
ADDED
|
@@ -0,0 +1,901 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modify from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
|
| 2 |
+
import inspect
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from einops import repeat, rearrange
|
| 12 |
+
|
| 13 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 14 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 15 |
+
|
| 16 |
+
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
| 17 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 18 |
+
|
| 19 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 20 |
+
import PIL.Image
|
| 21 |
+
|
| 22 |
+
from diffusers.models.attention_processor import (
|
| 23 |
+
AttnProcessor2_0,
|
| 24 |
+
FusedAttnProcessor2_0,
|
| 25 |
+
XFormersAttnProcessor,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from diffusers.utils import (
|
| 29 |
+
USE_PEFT_BACKEND,
|
| 30 |
+
deprecate,
|
| 31 |
+
is_invisible_watermark_available,
|
| 32 |
+
is_torch_xla_available,
|
| 33 |
+
logging,
|
| 34 |
+
replace_example_docstring,
|
| 35 |
+
scale_lora_layers,
|
| 36 |
+
unscale_lora_layers,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 40 |
+
from diffusers.loaders import (
|
| 41 |
+
FromSingleFileMixin,
|
| 42 |
+
IPAdapterMixin,
|
| 43 |
+
StableDiffusionXLLoraLoaderMixin,
|
| 44 |
+
TextualInversionLoaderMixin,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if is_invisible_watermark_available():
|
| 48 |
+
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
| 49 |
+
|
| 50 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
| 51 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline, \
|
| 52 |
+
retrieve_timesteps, rescale_noise_cfg
|
| 53 |
+
|
| 54 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode
|
| 55 |
+
|
| 56 |
+
if is_torch_xla_available():
|
| 57 |
+
import torch_xla.core.xla_model as xm
|
| 58 |
+
|
| 59 |
+
XLA_AVAILABLE = True
|
| 60 |
+
else:
|
| 61 |
+
XLA_AVAILABLE = False
|
| 62 |
+
|
| 63 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class StableDiffusionXLDecoderPipelineOutput(StableDiffusionXLPipelineOutput):
|
| 68 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 69 |
+
indices_semantic: Optional[torch.Tensor] = None
|
| 70 |
+
indices_pixel: Optional[torch.Tensor] = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def expand_dims_like(x, y):
|
| 74 |
+
while x.dim() != y.dim():
|
| 75 |
+
x = x.unsqueeze(-1)
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class AbstractEmbModel(nn.Module):
|
| 80 |
+
def __init__(self):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self._is_trainable = None
|
| 83 |
+
self._ucg_rate = None
|
| 84 |
+
self._input_key = None
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def is_trainable(self) -> bool:
|
| 88 |
+
return self._is_trainable
|
| 89 |
+
|
| 90 |
+
@property
|
| 91 |
+
def ucg_rate(self) -> Union[float, torch.Tensor]:
|
| 92 |
+
return self._ucg_rate
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def input_key(self) -> str:
|
| 96 |
+
return self._input_key
|
| 97 |
+
|
| 98 |
+
@is_trainable.setter
|
| 99 |
+
def is_trainable(self, value: bool):
|
| 100 |
+
self._is_trainable = value
|
| 101 |
+
|
| 102 |
+
@ucg_rate.setter
|
| 103 |
+
def ucg_rate(self, value: Union[float, torch.Tensor]):
|
| 104 |
+
self._ucg_rate = value
|
| 105 |
+
|
| 106 |
+
@input_key.setter
|
| 107 |
+
def input_key(self, value: str):
|
| 108 |
+
self._input_key = value
|
| 109 |
+
|
| 110 |
+
@is_trainable.deleter
|
| 111 |
+
def is_trainable(self):
|
| 112 |
+
del self._is_trainable
|
| 113 |
+
|
| 114 |
+
@ucg_rate.deleter
|
| 115 |
+
def ucg_rate(self):
|
| 116 |
+
del self._ucg_rate
|
| 117 |
+
|
| 118 |
+
@input_key.deleter
|
| 119 |
+
def input_key(self):
|
| 120 |
+
del self._input_key
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DualViTok2ImageEmbedder(AbstractEmbModel):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
image_processor=None,
|
| 127 |
+
vq_model=None,
|
| 128 |
+
device="cuda",
|
| 129 |
+
dtype=torch.float32,
|
| 130 |
+
freeze=True,
|
| 131 |
+
image_size=0,
|
| 132 |
+
resize_factor=1,
|
| 133 |
+
not_bicubic=True,
|
| 134 |
+
return_sequence=False,
|
| 135 |
+
grid_feature_scale=1,
|
| 136 |
+
texture_drop_prob=0,
|
| 137 |
+
semantic_drop_prob=0,
|
| 138 |
+
pixel_channel=32,
|
| 139 |
+
semantic_channel=32,
|
| 140 |
+
):
|
| 141 |
+
super().__init__()
|
| 142 |
+
vq_model.to(device=device, dtype=dtype)
|
| 143 |
+
vq_model.eval()
|
| 144 |
+
|
| 145 |
+
self.processor = image_processor
|
| 146 |
+
|
| 147 |
+
self.model = vq_model
|
| 148 |
+
self.device = device
|
| 149 |
+
if freeze:
|
| 150 |
+
self.freeze()
|
| 151 |
+
|
| 152 |
+
if image_size > 0:
|
| 153 |
+
preprocessor = [
|
| 154 |
+
Resize(image_size) if not_bicubic else Resize(image_size, interpolation=InterpolationMode.BICUBIC)]
|
| 155 |
+
preprocessor += [
|
| 156 |
+
CenterCrop(image_size),
|
| 157 |
+
]
|
| 158 |
+
self.preprocessor = Compose(preprocessor)
|
| 159 |
+
self.image_size = image_size
|
| 160 |
+
self.resize_factor = resize_factor
|
| 161 |
+
self.not_bicubic = not_bicubic
|
| 162 |
+
self.return_sequence = return_sequence
|
| 163 |
+
self.grid_feature_scale = grid_feature_scale
|
| 164 |
+
self.texture_drop_prob = texture_drop_prob
|
| 165 |
+
self.semantic_drop_prob = semantic_drop_prob
|
| 166 |
+
self.pixel_channel = pixel_channel
|
| 167 |
+
self.semantic_channel = semantic_channel
|
| 168 |
+
|
| 169 |
+
def freeze(self):
|
| 170 |
+
self.model = self.model.eval()
|
| 171 |
+
for param in self.parameters():
|
| 172 |
+
param.requires_grad = False
|
| 173 |
+
|
| 174 |
+
def vq_encode(self, image):
|
| 175 |
+
if image.ndim == 5:
|
| 176 |
+
assert image.size(1) == 1
|
| 177 |
+
image = image.squeeze(1)
|
| 178 |
+
bs, _, h, w = image.shape
|
| 179 |
+
|
| 180 |
+
if self.image_size > 0:
|
| 181 |
+
image = self.preprocessor(image)
|
| 182 |
+
else:
|
| 183 |
+
assert self.resize_factor > 0
|
| 184 |
+
preprocessor = Resize((int(h * self.resize_factor), int(w * self.resize_factor))) if self.not_bicubic else \
|
| 185 |
+
Resize((int(h * self.resize_factor), int(w * self.resize_factor)),
|
| 186 |
+
interpolation=InterpolationMode.BICUBIC)
|
| 187 |
+
image = preprocessor(image)
|
| 188 |
+
|
| 189 |
+
inputs = dict(image=image)
|
| 190 |
+
inputs = self.model.get_input(inputs)
|
| 191 |
+
|
| 192 |
+
(quant_semantic, diff_semantic, indices_semantic, target_semantic), \
|
| 193 |
+
(quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
|
| 194 |
+
return indices_semantic, indices_pixel
|
| 195 |
+
|
| 196 |
+
def vq_encode_code(self, image):
|
| 197 |
+
(quant_semantic, diff_semantic, indices_semantic, target_semantic), \
|
| 198 |
+
(quant_pixel, diff_pixel, indices_pixel) = self.vq_encode(image)
|
| 199 |
+
return indices_semantic, indices_pixel
|
| 200 |
+
|
| 201 |
+
def vq_decode_code(self, indices_semantic, indices_pixel):
|
| 202 |
+
return self.model.decode_code(indices_semantic, indices_pixel)
|
| 203 |
+
|
| 204 |
+
def forward(self, image, return_indices=False):
|
| 205 |
+
if image.ndim == 5:
|
| 206 |
+
assert image.size(1) == 1
|
| 207 |
+
image = image.squeeze(1)
|
| 208 |
+
bs, _, h, w = image.shape
|
| 209 |
+
|
| 210 |
+
if self.image_size > 0:
|
| 211 |
+
image = self.preprocessor(image)
|
| 212 |
+
else:
|
| 213 |
+
assert self.resize_factor > 0
|
| 214 |
+
preprocessor = Resize((int(h * self.resize_factor), int(w * self.resize_factor))) if self.not_bicubic else \
|
| 215 |
+
Resize((int(h * self.resize_factor), int(w * self.resize_factor)),
|
| 216 |
+
interpolation=InterpolationMode.BICUBIC)
|
| 217 |
+
image = preprocessor(image)
|
| 218 |
+
|
| 219 |
+
inputs = dict(image=image)
|
| 220 |
+
inputs = self.model.get_input(inputs)
|
| 221 |
+
|
| 222 |
+
(quant_semantic, diff_semantic, indices_semantic, target_semantic), \
|
| 223 |
+
(quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
|
| 224 |
+
|
| 225 |
+
feature = self.model.merge_quants(quant_semantic, quant_pixel)
|
| 226 |
+
|
| 227 |
+
if self.return_sequence:
|
| 228 |
+
feature = rearrange(feature, 'b c h w -> b h w c')
|
| 229 |
+
_, this_h, this_w, _ = feature.shape
|
| 230 |
+
feature = feature.view(bs, this_w * this_w, -1)
|
| 231 |
+
else:
|
| 232 |
+
feature = feature * self.grid_feature_scale
|
| 233 |
+
|
| 234 |
+
if return_indices:
|
| 235 |
+
return feature, indices_semantic, indices_pixel
|
| 236 |
+
|
| 237 |
+
return feature
|
| 238 |
+
|
| 239 |
+
def encode(self, img):
|
| 240 |
+
return self(img)
|
| 241 |
+
|
| 242 |
+
def indices_to_codes(self, semantic_indices, texture_indices):
|
| 243 |
+
quant_semantic, quant_texture = self.model.indices_to_codes(semantic_indices, texture_indices)
|
| 244 |
+
return self.model.merge_quants(quant_semantic, quant_texture)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class StableDiffusionXLDecoderPipeline(
|
| 248 |
+
DiffusionPipeline,
|
| 249 |
+
StableDiffusionMixin,
|
| 250 |
+
FromSingleFileMixin,
|
| 251 |
+
StableDiffusionXLLoraLoaderMixin,
|
| 252 |
+
TextualInversionLoaderMixin,
|
| 253 |
+
):
|
| 254 |
+
model_cpu_offload_seq = "vq_model_embedder->unet->vae"
|
| 255 |
+
_optional_components = [
|
| 256 |
+
"vq_model_embedder",
|
| 257 |
+
]
|
| 258 |
+
_callback_tensor_inputs = [
|
| 259 |
+
"latents",
|
| 260 |
+
"prompt_embeds",
|
| 261 |
+
"negative_prompt_embeds",
|
| 262 |
+
"add_text_embeds",
|
| 263 |
+
"add_time_ids",
|
| 264 |
+
"negative_pooled_prompt_embeds",
|
| 265 |
+
"negative_add_time_ids",
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
def __init__(
|
| 269 |
+
self,
|
| 270 |
+
vae: AutoencoderKL,
|
| 271 |
+
unet: UNet2DConditionModel,
|
| 272 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 273 |
+
force_zeros_for_empty_prompt: bool = True,
|
| 274 |
+
add_watermarker: Optional[bool] = None,
|
| 275 |
+
vq_image_processor=None,
|
| 276 |
+
vq_model=None,
|
| 277 |
+
):
|
| 278 |
+
super().__init__()
|
| 279 |
+
|
| 280 |
+
self.register_modules(
|
| 281 |
+
vae=vae,
|
| 282 |
+
unet=unet,
|
| 283 |
+
scheduler=scheduler,
|
| 284 |
+
)
|
| 285 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
| 286 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 287 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 288 |
+
|
| 289 |
+
self.default_sample_size = self.unet.config.sample_size
|
| 290 |
+
|
| 291 |
+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
| 292 |
+
|
| 293 |
+
if add_watermarker:
|
| 294 |
+
self.watermark = StableDiffusionXLWatermarker()
|
| 295 |
+
else:
|
| 296 |
+
self.watermark = None
|
| 297 |
+
|
| 298 |
+
self.empty_prompt_embeds = torch.zeros([1, 77, 2048]).to(device=unet.device, dtype=unet.dtype)
|
| 299 |
+
self.empty_pooled_prompt_embeds = torch.zeros([1, 1280]).to(device=unet.device, dtype=unet.dtype)
|
| 300 |
+
self.dualvitok_channels = vq_model.pixel_channel + vq_model.semantic_channel
|
| 301 |
+
|
| 302 |
+
self.resolution_group = ['(1024, 1024)', '(768, 1024)', '(1024, 768)', '(512, 2048)', '(2048, 512)',
|
| 303 |
+
'(640, 1920)', '(1920, 640)', '(768, 1536)', '(1536, 768)', '(768, 1152)',
|
| 304 |
+
'(1152, 768)', '(512, 512)']
|
| 305 |
+
|
| 306 |
+
embedder_kwargs = dict(image_size=0,
|
| 307 |
+
resize_factor=1,
|
| 308 |
+
return_sequence=False,
|
| 309 |
+
grid_feature_scale=1)
|
| 310 |
+
if isinstance(vq_model, DualViTok2ImageEmbedder):
|
| 311 |
+
self.vq_model_embedder = vq_model
|
| 312 |
+
else:
|
| 313 |
+
self.vq_model_embedder = DualViTok2ImageEmbedder(vq_image_processor, vq_model, **embedder_kwargs)
|
| 314 |
+
|
| 315 |
+
def vq_encode(self, image):
|
| 316 |
+
return self.vq_model_embedder.encode(image)
|
| 317 |
+
|
| 318 |
+
def vq_encode_code(self, image):
|
| 319 |
+
return self.vq_model_embedder.vq_encode_code(image)
|
| 320 |
+
|
| 321 |
+
def vq_decode_code(self, *args, **kwargs):
|
| 322 |
+
return self.vq_model_embedder.vq_decode_code(*args, **kwargs)
|
| 323 |
+
|
| 324 |
+
def indices_to_codes(self, *args, **kwargs):
|
| 325 |
+
return self.vq_model_embedder.indices_to_codes(*args, **kwargs)
|
| 326 |
+
|
| 327 |
+
def _get_add_time_ids(
|
| 328 |
+
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None,
|
| 329 |
+
resolution_index=None,
|
| 330 |
+
):
|
| 331 |
+
add_time_ids = [resolution_index] * 6
|
| 332 |
+
|
| 333 |
+
passed_add_embed_dim = (
|
| 334 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
| 335 |
+
)
|
| 336 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
| 337 |
+
|
| 338 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
| 339 |
+
raise ValueError(
|
| 340 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
| 344 |
+
return add_time_ids
|
| 345 |
+
|
| 346 |
+
def check_inputs(
|
| 347 |
+
self,
|
| 348 |
+
height,
|
| 349 |
+
width,
|
| 350 |
+
callback_steps,
|
| 351 |
+
callback_on_step_end_tensor_inputs=None,
|
| 352 |
+
):
|
| 353 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 354 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 355 |
+
|
| 356 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
| 357 |
+
raise ValueError(
|
| 358 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 359 |
+
f" {type(callback_steps)}."
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 363 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 364 |
+
):
|
| 365 |
+
raise ValueError(
|
| 366 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
| 370 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 371 |
+
shape = (
|
| 372 |
+
batch_size,
|
| 373 |
+
num_channels_latents,
|
| 374 |
+
int(height) // self.vae_scale_factor,
|
| 375 |
+
int(width) // self.vae_scale_factor,
|
| 376 |
+
)
|
| 377 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 378 |
+
raise ValueError(
|
| 379 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 380 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
if latents is None:
|
| 384 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 385 |
+
else:
|
| 386 |
+
latents = latents.to(device)
|
| 387 |
+
|
| 388 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 389 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 390 |
+
return latents
|
| 391 |
+
|
| 392 |
+
def upcast_vae(self):
|
| 393 |
+
dtype = self.vae.dtype
|
| 394 |
+
self.vae.to(dtype=torch.float32)
|
| 395 |
+
use_torch_2_0_or_xformers = isinstance(
|
| 396 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
| 397 |
+
(
|
| 398 |
+
AttnProcessor2_0,
|
| 399 |
+
XFormersAttnProcessor,
|
| 400 |
+
FusedAttnProcessor2_0,
|
| 401 |
+
),
|
| 402 |
+
)
|
| 403 |
+
# if xformers or torch_2_0 is used attention block does not need
|
| 404 |
+
# to be in float32 which can save lots of memory
|
| 405 |
+
if use_torch_2_0_or_xformers:
|
| 406 |
+
self.vae.post_quant_conv.to(dtype)
|
| 407 |
+
self.vae.decoder.conv_in.to(dtype)
|
| 408 |
+
self.vae.decoder.mid_block.to(dtype)
|
| 409 |
+
|
| 410 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
| 411 |
+
def get_guidance_scale_embedding(
|
| 412 |
+
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
| 413 |
+
) -> torch.Tensor:
|
| 414 |
+
"""
|
| 415 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
w (`torch.Tensor`):
|
| 419 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
| 420 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
| 421 |
+
Dimension of the embeddings to generate.
|
| 422 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
| 423 |
+
Data type of the generated embeddings.
|
| 424 |
+
|
| 425 |
+
Returns:
|
| 426 |
+
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
| 427 |
+
"""
|
| 428 |
+
assert len(w.shape) == 1
|
| 429 |
+
w = w * 1000.0
|
| 430 |
+
|
| 431 |
+
half_dim = embedding_dim // 2
|
| 432 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 433 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 434 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 435 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 436 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 437 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 438 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
| 439 |
+
return emb
|
| 440 |
+
|
| 441 |
+
@property
|
| 442 |
+
def guidance_scale(self):
|
| 443 |
+
return self._guidance_scale
|
| 444 |
+
|
| 445 |
+
@property
|
| 446 |
+
def guidance_rescale(self):
|
| 447 |
+
return self._guidance_rescale
|
| 448 |
+
|
| 449 |
+
@property
|
| 450 |
+
def clip_skip(self):
|
| 451 |
+
return self._clip_skip
|
| 452 |
+
|
| 453 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 454 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 455 |
+
# corresponds to doing no classifier free guidance.
|
| 456 |
+
@property
|
| 457 |
+
def do_classifier_free_guidance(self):
|
| 458 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
| 459 |
+
|
| 460 |
+
@property
|
| 461 |
+
def cross_attention_kwargs(self):
|
| 462 |
+
return self._cross_attention_kwargs
|
| 463 |
+
|
| 464 |
+
@property
|
| 465 |
+
def denoising_end(self):
|
| 466 |
+
return self._denoising_end
|
| 467 |
+
|
| 468 |
+
@property
|
| 469 |
+
def num_timesteps(self):
|
| 470 |
+
return self._num_timesteps
|
| 471 |
+
|
| 472 |
+
@property
|
| 473 |
+
def interrupt(self):
|
| 474 |
+
return self._interrupt
|
| 475 |
+
|
| 476 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 477 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 478 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 479 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 480 |
+
# and should be between [0, 1]
|
| 481 |
+
|
| 482 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 483 |
+
extra_step_kwargs = {}
|
| 484 |
+
if accepts_eta:
|
| 485 |
+
extra_step_kwargs["eta"] = eta
|
| 486 |
+
|
| 487 |
+
# check if the scheduler accepts generator
|
| 488 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 489 |
+
if accepts_generator:
|
| 490 |
+
extra_step_kwargs["generator"] = generator
|
| 491 |
+
return extra_step_kwargs
|
| 492 |
+
|
| 493 |
+
@torch.no_grad()
|
| 494 |
+
def __call__(
|
| 495 |
+
self,
|
| 496 |
+
vq_indices: Optional[List] = None,
|
| 497 |
+
vq_embeds: Optional[torch.Tensor] = None,
|
| 498 |
+
images: Optional[PipelineImageInput] = None,
|
| 499 |
+
height: Optional[int] = None,
|
| 500 |
+
width: Optional[int] = None,
|
| 501 |
+
num_inference_steps: int = 50,
|
| 502 |
+
timesteps: List[int] = None,
|
| 503 |
+
sigmas: List[float] = None,
|
| 504 |
+
denoising_end: Optional[float] = None,
|
| 505 |
+
guidance_scale: float = 2.0,
|
| 506 |
+
eta: float = 0.0,
|
| 507 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 508 |
+
latents: Optional[torch.Tensor] = None,
|
| 509 |
+
output_type: Optional[str] = "pil",
|
| 510 |
+
return_dict: bool = True,
|
| 511 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 512 |
+
guidance_rescale: float = 0.0,
|
| 513 |
+
original_size: Optional[Tuple[int, int]] = None,
|
| 514 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 515 |
+
target_size: Optional[Tuple[int, int]] = None,
|
| 516 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
| 517 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 518 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
| 519 |
+
clip_skip: Optional[int] = None,
|
| 520 |
+
callback_on_step_end: Optional[
|
| 521 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 522 |
+
] = None,
|
| 523 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 524 |
+
**kwargs,
|
| 525 |
+
):
|
| 526 |
+
r"""
|
| 527 |
+
Function invoked when calling the pipeline for generation.
|
| 528 |
+
|
| 529 |
+
Args:
|
| 530 |
+
vq_indices (`Optional[PipelineImageInput]`, *optional*):
|
| 531 |
+
The VQ indices for semantic and pixel tokens. Should be a tuple of (semantic_indices, pixel_indices).
|
| 532 |
+
images (`Optional[PipelineImageInput]`, *optional*):
|
| 533 |
+
Input images in range [-1, 1] as torch.Tensor with shape (batch_size, channels, height, width).
|
| 534 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 535 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 536 |
+
Anything below 512 pixels won't work well for
|
| 537 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 538 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 539 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 540 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 541 |
+
Anything below 512 pixels won't work well for
|
| 542 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 543 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 544 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 545 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 546 |
+
expense of slower inference.
|
| 547 |
+
timesteps (`List[int]`, *optional*):
|
| 548 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 549 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 550 |
+
passed will be used. Must be in descending order.
|
| 551 |
+
sigmas (`List[float]`, *optional*):
|
| 552 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 553 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 554 |
+
will be used.
|
| 555 |
+
denoising_end (`float`, *optional*):
|
| 556 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
| 557 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
| 558 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
| 559 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
| 560 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
| 561 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
| 562 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 563 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 564 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 565 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 566 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 567 |
+
usually at the expense of lower image quality.
|
| 568 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 569 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 570 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 571 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 572 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 573 |
+
to make generation deterministic.
|
| 574 |
+
latents (`torch.Tensor`, *optional*):
|
| 575 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 576 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 577 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 578 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 579 |
+
The output format of the generate image. Choose between
|
| 580 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 581 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 582 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 583 |
+
of a plain tuple.
|
| 584 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 585 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 586 |
+
`self.processor` in
|
| 587 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 588 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 589 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
| 590 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
| 591 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
| 592 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
| 593 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 594 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
| 595 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
| 596 |
+
explained in section 2.2 of
|
| 597 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 598 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 599 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
| 600 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
| 601 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 602 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 603 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 604 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
| 605 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
| 606 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 607 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 608 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 609 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 610 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 611 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 612 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 613 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 614 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 615 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 616 |
+
|
| 617 |
+
Examples:
|
| 618 |
+
|
| 619 |
+
Returns:
|
| 620 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
| 621 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
| 622 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 623 |
+
"""
|
| 624 |
+
|
| 625 |
+
callback = kwargs.pop("callback", None)
|
| 626 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 627 |
+
|
| 628 |
+
if callback is not None:
|
| 629 |
+
deprecate(
|
| 630 |
+
"callback",
|
| 631 |
+
"1.0.0",
|
| 632 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 633 |
+
)
|
| 634 |
+
if callback_steps is not None:
|
| 635 |
+
deprecate(
|
| 636 |
+
"callback_steps",
|
| 637 |
+
"1.0.0",
|
| 638 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 642 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 643 |
+
|
| 644 |
+
# 0. Default height and width to unet
|
| 645 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 646 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 647 |
+
|
| 648 |
+
original_size = original_size or (height, width)
|
| 649 |
+
target_size = target_size or (height, width)
|
| 650 |
+
|
| 651 |
+
# 1. Check inputs. Raise error if not correct
|
| 652 |
+
self.check_inputs(
|
| 653 |
+
height,
|
| 654 |
+
width,
|
| 655 |
+
callback_steps,
|
| 656 |
+
callback_on_step_end_tensor_inputs,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
self._guidance_scale = guidance_scale
|
| 660 |
+
self._guidance_rescale = guidance_rescale
|
| 661 |
+
self._clip_skip = clip_skip
|
| 662 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 663 |
+
self._denoising_end = denoising_end
|
| 664 |
+
self._interrupt = False
|
| 665 |
+
|
| 666 |
+
# 2. encode vq_embeds
|
| 667 |
+
assert images is not None or vq_indices is not None or vq_embeds is not None
|
| 668 |
+
batch_size = len(images) if images is not None else len(vq_indices[0])
|
| 669 |
+
|
| 670 |
+
if images:
|
| 671 |
+
vq_embeds, indices_semantic, indices_pixel = self.vq_model_embedder(images, return_indices=True)
|
| 672 |
+
elif vq_indices:
|
| 673 |
+
indices_semantic, indices_pixel = vq_indices[0], vq_indices[1]
|
| 674 |
+
vq_embeds = self.vq_model_embedder.indices_to_codes(vq_indices[0], vq_indices[1])
|
| 675 |
+
elif vq_embeds:
|
| 676 |
+
if isinstance(vq_embeds, list):
|
| 677 |
+
vq_embeds = self.vq_model_embedder.merge_quants(vq_embeds)
|
| 678 |
+
indices_semantic, indices_pixel = None, None
|
| 679 |
+
else:
|
| 680 |
+
raise ValueError("No valid input provided")
|
| 681 |
+
|
| 682 |
+
device = self._execution_device
|
| 683 |
+
|
| 684 |
+
# 3. Encode input prompt
|
| 685 |
+
lora_scale = (
|
| 686 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
prompt_embeds = repeat(self.empty_prompt_embeds, '1 l c -> b l c', b=batch_size)
|
| 690 |
+
pooled_prompt_embeds = repeat(self.empty_pooled_prompt_embeds, '1 c -> b c', b=batch_size)
|
| 691 |
+
|
| 692 |
+
negative_prompt_embeds = prompt_embeds
|
| 693 |
+
negative_pooled_prompt_embeds = pooled_prompt_embeds
|
| 694 |
+
|
| 695 |
+
# 4. Prepare timesteps
|
| 696 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 697 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
# 5. Prepare latent variables
|
| 701 |
+
# num_channels_latents = self.unet.config.in_channels
|
| 702 |
+
num_channels_latents = 4
|
| 703 |
+
latents = self.prepare_latents(
|
| 704 |
+
batch_size,
|
| 705 |
+
num_channels_latents,
|
| 706 |
+
height,
|
| 707 |
+
width,
|
| 708 |
+
prompt_embeds.dtype,
|
| 709 |
+
device,
|
| 710 |
+
generator,
|
| 711 |
+
latents,
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 715 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 716 |
+
|
| 717 |
+
# 7. Prepare added time ids & embeddings
|
| 718 |
+
add_text_embeds = pooled_prompt_embeds
|
| 719 |
+
text_encoder_projection_dim = 1280
|
| 720 |
+
|
| 721 |
+
resolution = f'({width}, {height})'
|
| 722 |
+
assert resolution in self.resolution_group, f"resolution are not in resolution group. Got {resolution}. Candidates:{self.resolution_group}"
|
| 723 |
+
resolution_index = self.resolution_group.index(resolution)
|
| 724 |
+
# resolution_index = None
|
| 725 |
+
|
| 726 |
+
add_time_ids = self._get_add_time_ids(
|
| 727 |
+
original_size,
|
| 728 |
+
crops_coords_top_left,
|
| 729 |
+
target_size,
|
| 730 |
+
dtype=prompt_embeds.dtype,
|
| 731 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 732 |
+
resolution_index=resolution_index,
|
| 733 |
+
)
|
| 734 |
+
if negative_original_size is not None and negative_target_size is not None:
|
| 735 |
+
negative_add_time_ids = self._get_add_time_ids(
|
| 736 |
+
negative_original_size,
|
| 737 |
+
negative_crops_coords_top_left,
|
| 738 |
+
negative_target_size,
|
| 739 |
+
dtype=prompt_embeds.dtype,
|
| 740 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 741 |
+
)
|
| 742 |
+
else:
|
| 743 |
+
negative_add_time_ids = add_time_ids
|
| 744 |
+
|
| 745 |
+
if self.do_classifier_free_guidance:
|
| 746 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 747 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
| 748 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 749 |
+
|
| 750 |
+
prompt_embeds = prompt_embeds.to(device)
|
| 751 |
+
add_text_embeds = add_text_embeds.to(device)
|
| 752 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)
|
| 753 |
+
|
| 754 |
+
# 8. Denoising loop
|
| 755 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 756 |
+
|
| 757 |
+
# 8.1 Apply denoising_end
|
| 758 |
+
if (
|
| 759 |
+
self.denoising_end is not None
|
| 760 |
+
and isinstance(self.denoising_end, float)
|
| 761 |
+
and self.denoising_end > 0
|
| 762 |
+
and self.denoising_end < 1
|
| 763 |
+
):
|
| 764 |
+
discrete_timestep_cutoff = int(
|
| 765 |
+
round(
|
| 766 |
+
self.scheduler.config.num_train_timesteps
|
| 767 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
| 768 |
+
)
|
| 769 |
+
)
|
| 770 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
| 771 |
+
timesteps = timesteps[:num_inference_steps]
|
| 772 |
+
|
| 773 |
+
# 9. Optionally get Guidance Scale Embedding
|
| 774 |
+
timestep_cond = None
|
| 775 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 776 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size)
|
| 777 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 778 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 779 |
+
).to(device=device, dtype=latents.dtype)
|
| 780 |
+
|
| 781 |
+
self._num_timesteps = len(timesteps)
|
| 782 |
+
# with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 783 |
+
for i, t in enumerate(timesteps):
|
| 784 |
+
if self.interrupt:
|
| 785 |
+
continue
|
| 786 |
+
|
| 787 |
+
# expand the latents if we are doing classifier free guidance
|
| 788 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 789 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 790 |
+
|
| 791 |
+
vq_embeds = vq_embeds.to(latent_model_input) if vq_embeds.size(
|
| 792 |
+
-1) == latent_model_input.size(
|
| 793 |
+
-1) else \
|
| 794 |
+
torch.nn.functional.interpolate(vq_embeds.to(latent_model_input),
|
| 795 |
+
size=latent_model_input.shape[-2:])
|
| 796 |
+
vq_embeds_input = torch.cat([torch.zeros_like(vq_embeds),
|
| 797 |
+
vq_embeds]) if self.do_classifier_free_guidance else vq_embeds
|
| 798 |
+
|
| 799 |
+
# predict the noise residual
|
| 800 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
| 801 |
+
|
| 802 |
+
latent_model_input = torch.cat([latent_model_input, vq_embeds_input], dim=1)
|
| 803 |
+
noise_pred = self.unet(
|
| 804 |
+
latent_model_input,
|
| 805 |
+
t,
|
| 806 |
+
encoder_hidden_states=prompt_embeds,
|
| 807 |
+
timestep_cond=timestep_cond,
|
| 808 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 809 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 810 |
+
return_dict=False,
|
| 811 |
+
)[0]
|
| 812 |
+
|
| 813 |
+
# perform guidance
|
| 814 |
+
if self.do_classifier_free_guidance:
|
| 815 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 816 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 817 |
+
|
| 818 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 819 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 820 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=self.guidance_rescale)
|
| 821 |
+
|
| 822 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 823 |
+
latents_dtype = latents.dtype
|
| 824 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 825 |
+
if latents.dtype != latents_dtype:
|
| 826 |
+
if torch.backends.mps.is_available():
|
| 827 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 828 |
+
latents = latents.to(latents_dtype)
|
| 829 |
+
|
| 830 |
+
if callback_on_step_end is not None:
|
| 831 |
+
callback_kwargs = {}
|
| 832 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 833 |
+
callback_kwargs[k] = locals()[k]
|
| 834 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 835 |
+
|
| 836 |
+
latents = callback_outputs.pop("latents", latents)
|
| 837 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 838 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
| 839 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
| 840 |
+
|
| 841 |
+
# call the callback, if provided
|
| 842 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 843 |
+
# progress_bar.update()
|
| 844 |
+
if callback is not None and i % callback_steps == 0:
|
| 845 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 846 |
+
callback(step_idx, t, latents)
|
| 847 |
+
|
| 848 |
+
if XLA_AVAILABLE:
|
| 849 |
+
xm.mark_step()
|
| 850 |
+
|
| 851 |
+
if not output_type == "latent":
|
| 852 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
| 853 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 854 |
+
|
| 855 |
+
if needs_upcasting:
|
| 856 |
+
self.upcast_vae()
|
| 857 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 858 |
+
elif latents.dtype != self.vae.dtype:
|
| 859 |
+
if torch.backends.mps.is_available():
|
| 860 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 861 |
+
self.vae = self.vae.to(latents.dtype)
|
| 862 |
+
|
| 863 |
+
# unscale/denormalize the latents
|
| 864 |
+
# denormalize with the mean and std if available and not None
|
| 865 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
| 866 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
| 867 |
+
if has_latents_mean and has_latents_std:
|
| 868 |
+
latents_mean = (
|
| 869 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
| 870 |
+
)
|
| 871 |
+
latents_std = (
|
| 872 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
| 873 |
+
)
|
| 874 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
| 875 |
+
else:
|
| 876 |
+
latents = latents / self.vae.config.scaling_factor
|
| 877 |
+
|
| 878 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 879 |
+
|
| 880 |
+
# cast back to fp16 if needed
|
| 881 |
+
if needs_upcasting:
|
| 882 |
+
self.vae.to(dtype=torch.float16)
|
| 883 |
+
else:
|
| 884 |
+
image = latents
|
| 885 |
+
|
| 886 |
+
if not output_type == "latent":
|
| 887 |
+
# apply watermark if available
|
| 888 |
+
if self.watermark is not None:
|
| 889 |
+
image = self.watermark.apply_watermark(image)
|
| 890 |
+
|
| 891 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 892 |
+
|
| 893 |
+
# Offload all models
|
| 894 |
+
self.maybe_free_model_hooks()
|
| 895 |
+
|
| 896 |
+
if not return_dict:
|
| 897 |
+
return (image,)
|
| 898 |
+
|
| 899 |
+
return StableDiffusionXLDecoderPipelineOutput(images=image,
|
| 900 |
+
indices_semantic=indices_semantic,
|
| 901 |
+
indices_pixel=indices_pixel)
|