Diffusers · PyTorch · custom_code

huangrh9 committed · Commit 5ca5652 · verified · 1 Parent(s): 3886cb9

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,65 @@
# DualViTok

<div align="center">

<img src="https://illume-unified-mllm.github.io/static/images/logo.png" width="100em"></img>

📄 [Paper](https://arxiv.org/abs/2504.01934) |
🌐 [Project-Page](https://illume-unified-mllm.github.io/) |
📦 [Github](https://github.com/illume-unified-mllm/ILLUME_plus)

</div>

## Introduction

**DualViTok** (Dual Vision Tokenizer) is a dual-branch vision tokenizer designed to capture both deep semantics and fine-grained textures. It is proposed in [ILLUME+](https://arxiv.org/abs/2504.01934). The semantic branch uses a pre-trained, text-aligned vision encoder for semantic feature extraction, supervised by a feature-reconstruction loss. In parallel, the pixel branch integrates quantized features from both the semantic encoder and a CNN-based pixel encoder to enhance pixel-level reconstruction. To improve robustness against incorrect token predictions in autoregressive generation, noise is injected during training by randomly perturbing visual tokens. Despite its simplicity, DualViTok is designed specifically for unified models, preserving both semantics and texture while keeping token decoding robust.

<div align="center">
<img src="https://illume-unified-mllm.github.io/static/images/tokenizer_framework.png" width="80%"></img>
</div>
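The noise injection mentioned above can be pictured with a small sketch (illustrative only, not the actual ILLUME+ training code; the perturbation probability `p` and uniform resampling scheme are assumptions):

```python
import torch

def perturb_tokens(indices: torch.Tensor, codebook_size: int, p: float = 0.1) -> torch.Tensor:
    """Randomly replace a fraction `p` of visual token indices with random codebook entries."""
    mask = torch.rand(indices.shape, device=indices.device) < p   # which tokens to corrupt
    random_codes = torch.randint_like(indices, codebook_size)     # uniform random code ids
    return torch.where(mask, random_codes, indices)
```

Training the decoder on perturbed tokens of this kind is what makes decoding tolerant to occasional wrong token predictions from the autoregressive model.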
## Quickstart for Autoencoding
```python
from PIL import Image
import torch
from transformers import AutoModel, AutoImageProcessor

MODEL_HUB = "ILLUME-MLLM/dualvitok/"

model = AutoModel.from_pretrained(MODEL_HUB, trust_remote_code=True).eval().cuda()
processor = AutoImageProcessor.from_pretrained(MODEL_HUB, trust_remote_code=True)

# Load the diffusion decoder (optional).
# diffusion_decoder = model.build_sdxl_decoder('ILLUME-MLLM/dualvitok-sdxl-decoder')

# TODO: you need to modify the path here.
IMAGE_PATH = "YOUR_IMAGE_PATH"

image = Image.open(IMAGE_PATH)

image = processor(image, return_tensors="pt")["pixel_values"]
image = image.unsqueeze(0).cuda()

with torch.no_grad():
    (quant_semantic, diff_semantic, indices_semantic, _), \
    (quant_pixel, diff_pixel, indices_pixel) = model.encode(image)

    recon = model.decode(quant_semantic, quant_pixel)

# Alternatively, decode directly from the codes.
# recon = model.decode_code(indices_semantic, indices_pixel)

print(recon.shape)
recon_image = processor.postprocess(recon)["pixel_values"][0]
recon_image.save("recon_image.png")

# The diffusion decoder only supports 11 resolutions; check `diffusion_decoder.resolution_group`.
# (`height` and `width` below are placeholders for your target size.)
# diffusion_recon = diffusion_decoder(  # pass either vq_indices or vq_embeds
#     vq_indices=(indices_semantic, indices_pixel),
#     vq_embeds=(quant_semantic, quant_pixel),
#     height=height * 2,
#     width=width * 2,
#     num_inference_steps=50,
#     guidance_scale=1.5,
# )
# diffusion_recon.images[0].save("diffusion_recon_image.png")
```
config.json ADDED
@@ -0,0 +1,75 @@
{
  "auto_map": {
    "AutoConfig": "configuration_dualvitok.DualViTokConfig",
    "AutoModel": "modeling_dualvitok.DualViTok"
  },
  "architectures": [
    "DualViTok"
  ],
  "semantic_encoder": {
    "pretrained_semantic_encoder": "Emova-ollm/qwen2vit600m",
    "z_channels": 32,
    "num_blocks": 4,
    "out_layer": "linear",
    "embed_dim": 1280,
    "target_mlp": "norm"
  },
  "semantic_decoder": {
    "z_channels": 32,
    "num_blocks": 4,
    "embed_dim": 1280,
    "out_layer": "linear_norm",
    "out_channels": 3584
  },
  "semantic_quantizer_type": "simvq",
  "pixel_quantizer_type": "simvq",
  "semantic_quantizer_codebook_size": 32768,
  "pixel_quantizer_codebook_size": 98304,
  "attn_implementation": "eager",
  "pixel_encoder": {
    "codebook_size": 98304,
    "embed_dim": 32,
    "z_channels": 32,
    "double_z": false,
    "in_channels": 3,
    "out_channels": 3,
    "ch": 128,
    "ch_mult": [1, 1, 2, 2, 4],
    "num_res_blocks": 2,
    "attn_resolutions": [4],
    "dropout": 0.0,
    "use_dc_up_down_blocks": true
  },
  "pixel_decoder": {
    "codebook_size": 98304,
    "embed_dim": 64,
    "z_channels": 64,
    "double_z": false,
    "in_channels": 3,
    "out_channels": 3,
    "ch": 384,
    "ch_mult": [1, 1, 2, 2, 4],
    "num_res_blocks": 2,
    "attn_resolutions": [4],
    "dropout": 0.0,
    "use_dc_up_down_blocks": true
  },
  "torch_dtype": "float16",
  "transformers_version": "4.44.2"
}
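For reference, the `auto_map` above is what routes `AutoConfig`/`AutoModel` to the custom classes in this repo when `trust_remote_code=True`. A minimal load sketch (hub id as in the quickstart):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("ILLUME-MLLM/dualvitok", trust_remote_code=True)
print(config.semantic_quantizer_codebook_size)  # 32768
print(config.pixel_quantizer_codebook_size)     # 98304
```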
configuration_dualvitok.py ADDED
@@ -0,0 +1,153 @@
import os
from typing import List, Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from .configuration_movqgan import MoVQConfig
from .modeling_rope_utils import rope_config_validation
from .configuration_qwen2vit import Qwen2VLVisionConfig

logger = logging.get_logger(__name__)


class SemanticEncoderConfig(PretrainedConfig):
    model_type = "DualViTokSemanticEncoder"

    def __init__(
        self,
        pretrained_semantic_encoder='Emova-ollm/qwen2vit600m',
        z_channels=32,
        num_blocks=4,
        embed_dim=1280,
        out_layer='linear',
        target_mlp='norm',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.pretrained_semantic_encoder = pretrained_semantic_encoder
        self.z_channels = z_channels
        self.num_blocks = num_blocks
        self.out_layer = out_layer
        self.embed_dim = embed_dim
        self.target_mlp = target_mlp


class SemanticDecoderConfig(PretrainedConfig):
    model_type = "DualViTokSemanticDecoder"

    def __init__(
        self,
        z_channels=32,
        num_blocks=4,
        embed_dim=1280,
        out_layer='linear_norm',
        out_channels=3584,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.z_channels = z_channels
        self.num_blocks = num_blocks
        self.embed_dim = embed_dim
        self.out_layer = out_layer
        self.out_channels = out_channels


class DualViTokConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DualViTok`]. It is used to instantiate a
    DualViTok model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to the tokenizer presented in the ILLUME+ paper.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        semantic_encoder (`dict`, *optional*):
            Configuration of the semantic-branch encoder; see [`SemanticEncoderConfig`].
        semantic_decoder (`dict`, *optional*):
            Configuration of the semantic-branch decoder; see [`SemanticDecoderConfig`].
        pixel_encoder (`dict`, *optional*):
            Configuration of the pixel-branch encoder; see [`MoVQConfig`].
        pixel_decoder (`dict`, *optional*):
            Configuration of the pixel-branch decoder. Defaults to the pixel-encoder configuration.
        semantic_quantizer_type (`str`, *optional*, defaults to `"simvq"`):
            Quantizer type used for the semantic branch.
        pixel_quantizer_type (`str`, *optional*, defaults to `"simvq"`):
            Quantizer type used for the pixel branch.
        semantic_quantizer_codebook_size (`int`, *optional*, defaults to 32768):
            Codebook size of the semantic quantizer.
        pixel_quantizer_codebook_size (`int`, *optional*, defaults to 98304):
            Codebook size of the pixel quantizer.
        attn_implementation (`str`, *optional*, defaults to `"sdpa"`):
            Attention implementation to use.

    ```python
    >>> from transformers import DualViTok, DualViTokConfig

    >>> # Initializing a DualViTok configuration
    >>> configuration = DualViTokConfig()

    >>> # Initializing a model from the configuration
    >>> model = DualViTok(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "DualViTok"

    def __init__(
        self,
        semantic_encoder=None,
        semantic_decoder=None,
        pixel_encoder=None,
        pixel_decoder=None,
        semantic_quantizer_type='simvq',
        pixel_quantizer_type='simvq',
        semantic_quantizer_codebook_size=32768,
        pixel_quantizer_codebook_size=98304,
        attn_implementation='sdpa',
        **kwargs,
    ):
        super().__init__(**kwargs)
        if semantic_encoder is None:
            self.semantic_encoder = SemanticEncoderConfig()
        else:
            self.semantic_encoder = SemanticEncoderConfig(**semantic_encoder)
        if semantic_decoder is None:
            self.semantic_decoder = SemanticDecoderConfig()
        else:
            self.semantic_decoder = SemanticDecoderConfig(**semantic_decoder)

        self.semantic_quantizer_type = semantic_quantizer_type
        self.pixel_quantizer_type = pixel_quantizer_type
        self.semantic_quantizer_codebook_size = semantic_quantizer_codebook_size
        self.pixel_quantizer_codebook_size = pixel_quantizer_codebook_size

        if pixel_encoder is None:
            self.pixel_encoder = MoVQConfig()
        else:
            self.pixel_encoder = MoVQConfig(**pixel_encoder)

        self.pixel_decoder = self.pixel_encoder if pixel_decoder is None else MoVQConfig(**pixel_decoder)

        self.attn_implementation = attn_implementation

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], attn_implementation='sdpa', **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, attn_implementation=attn_implementation, **kwargs)
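Sub-configs can be passed as plain dicts, mirroring `config.json`. A sketch (values copied from the `config.json` above; assumes the module is importable locally):

```python
config = DualViTokConfig(
    semantic_encoder={"embed_dim": 1280, "num_blocks": 4},
    pixel_encoder={"ch": 128, "ch_mult": [1, 1, 2, 2, 4], "use_dc_up_down_blocks": True},
    attn_implementation="eager",
)
print(type(config.pixel_encoder).__name__)           # MoVQConfig
print(config.pixel_decoder is config.pixel_encoder)  # True when pixel_decoder is omitted
```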
configuration_movqgan.py ADDED
@@ -0,0 +1,90 @@
""" MoVQ model configuration """

from typing import List

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class MoVQConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MoVQ`]. It is used to instantiate a MoVQ
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a configuration similar to the VQ model presented in the paper.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        codebook_size (`int`, *optional*, defaults to 32768):
            Codebook size of the VQ model.
        embed_dim (`int`, *optional*, defaults to 4):
            Dimension of the quantized vectors in the codebook.
        z_channels (`int`, *optional*, defaults to 4):
            Dimension of the output channels of the encoder and the input channels of the decoder.
        double_z (`bool`, *optional*, defaults to `False`):
            Whether to double the output dim of the encoder.
        in_channels (`int`, *optional*, defaults to 3):
            Number of input channels of the encoder.
        out_channels (`int`, *optional*, defaults to 3):
            Number of output channels of the decoder.
        ch (`int`, *optional*, defaults to 256):
            Base channel count of the intermediate blocks.
        ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
            Channel scaling factors of the intermediate blocks.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Number of residual blocks in each stage.
        attn_resolutions (`List[int]`, *optional*, defaults to `[3]`):
            Stage indices at which to apply attention.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability.
        use_dc_up_down_blocks (`bool`, *optional*, defaults to `False`):
            Whether to use DC up/down blocks in the encoder and decoder.

    ```python
    >>> from transformers import MoVQ, MoVQConfig

    >>> # Initializing a MoVQ configuration
    >>> configuration = MoVQConfig()

    >>> # Initializing a model from the configuration
    >>> model = MoVQ(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "MoVQ"

    def __init__(
        self,
        codebook_size: int = 32768,
        embed_dim: int = 4,
        z_channels: int = 4,
        double_z: bool = False,
        in_channels: int = 3,
        out_channels: int = 3,
        ch: int = 256,
        ch_mult: List[int] = [1, 2, 2, 4],
        num_res_blocks: int = 2,
        attn_resolutions: List[int] = [3],
        dropout: float = 0.0,
        use_dc_up_down_blocks=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.codebook_size = codebook_size
        self.embed_dim = embed_dim
        self.z_channels = z_channels
        self.double_z = double_z
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ch = ch
        self.ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.dropout = dropout
        self.use_dc_up_down_blocks = use_dc_up_down_blocks
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen2VL model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+ from .modeling_rope_utils import rope_config_validation
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class Qwen2VLVisionConfig(PretrainedConfig):
28
+ model_type = "qwen2_vl"
29
+
30
+ def __init__(
31
+ self,
32
+ depth=32,
33
+ embed_dim=1280,
34
+ hidden_size=3584,
35
+ hidden_act="quick_gelu",
36
+ mlp_ratio=4,
37
+ num_heads=16,
38
+ in_channels=3,
39
+ patch_size=14,
40
+ spatial_merge_size=2,
41
+ temporal_patch_size=2,
42
+ attn_implementation='eager',
43
+ init_weights=False,
44
+ **kwargs,
45
+ ):
46
+ super().__init__(**kwargs)
47
+
48
+ self.depth = depth
49
+ self.embed_dim = embed_dim
50
+ self.hidden_size = hidden_size
51
+ self.hidden_act = hidden_act
52
+ self.mlp_ratio = mlp_ratio
53
+ self.num_heads = num_heads
54
+ self.in_channels = in_channels
55
+ self.patch_size = patch_size
56
+ self.spatial_merge_size = spatial_merge_size
57
+ self.temporal_patch_size = temporal_patch_size
58
+ self.attn_implementation = attn_implementation if attn_implementation else 'eager'
59
+
60
+ self.init_weights = init_weights
61
+
62
+ @classmethod
63
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
64
+ cls._set_token_in_kwargs(kwargs)
65
+
66
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
67
+
68
+ # if config_dict.get("model_type") == "qwen2_vl":
69
+ # config_dict = config_dict["vision_config"]
70
+
71
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
72
+ logger.warning(
73
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
74
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
75
+ )
76
+
77
+ return cls.from_dict(config_dict, **kwargs)
78
+
79
+
80
+ class Qwen2VLConfig(PretrainedConfig):
81
+ r"""
82
+ This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
83
+ Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
84
+ with the defaults will yield a similar configuration to that of
85
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
86
+
87
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
88
+ documentation from [`PretrainedConfig`] for more information.
89
+
90
+
91
+ Args:
92
+ vocab_size (`int`, *optional*, defaults to 152064):
93
+ Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
94
+ `inputs_ids` passed when calling [`Qwen2VLModel`]
95
+ hidden_size (`int`, *optional*, defaults to 8192):
96
+ Dimension of the hidden representations.
97
+ intermediate_size (`int`, *optional*, defaults to 29568):
98
+ Dimension of the MLP representations.
99
+ num_hidden_layers (`int`, *optional*, defaults to 80):
100
+ Number of hidden layers in the Transformer encoder.
101
+ num_attention_heads (`int`, *optional*, defaults to 64):
102
+ Number of attention heads for each attention layer in the Transformer encoder.
103
+ num_key_value_heads (`int`, *optional*, defaults to 8):
104
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
105
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
106
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
107
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
108
+ by meanpooling all the original heads within that group. For more details checkout [this
109
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
110
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
111
+ The non-linear activation function (function or string) in the decoder.
112
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
113
+ The maximum sequence length that this model might ever be used with.
114
+ initializer_range (`float`, *optional*, defaults to 0.02):
115
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
116
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
117
+ The epsilon used by the rms normalization layers.
118
+ use_cache (`bool`, *optional*, defaults to `True`):
119
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
120
+ relevant if `config.is_decoder=True`.
121
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
122
+ Whether the model's input and output word embeddings should be tied.
123
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
124
+ The base period of the RoPE embeddings.
125
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
126
+ Whether to use sliding window attention.
127
+ sliding_window (`int`, *optional*, defaults to 4096):
128
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
129
+ max_window_layers (`int`, *optional*, defaults to 80):
130
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
131
+ attention_dropout (`float`, *optional*, defaults to 0.0):
132
+ The dropout ratio for the attention probabilities.
133
+ vision_config (`Dict`, *optional*):
134
+ The config for the visual encoder initialization.
135
+ rope_scaling (`Dict`, *optional*):
136
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
137
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
138
+ accordingly.
139
+ Expected contents:
140
+ `rope_type` (`str`):
141
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
142
+ 'llama3'], with 'default' being the original RoPE implementation.
143
+ `factor` (`float`, *optional*):
144
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
145
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
146
+ original maximum pre-trained length.
147
+ `original_max_position_embeddings` (`int`, *optional*):
148
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
149
+ pretraining.
150
+ `attention_factor` (`float`, *optional*):
151
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
152
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
153
+ `factor` field to infer the suggested value.
154
+ `beta_fast` (`float`, *optional*):
155
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
156
+ ramp function. If unspecified, it defaults to 32.
157
+ `beta_slow` (`float`, *optional*):
158
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
159
+ ramp function. If unspecified, it defaults to 1.
160
+ `short_factor` (`List[float]`, *optional*):
161
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
162
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
163
+ size divided by the number of attention heads divided by 2
164
+ `long_factor` (`List[float]`, *optional*):
165
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
166
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
167
+ size divided by the number of attention heads divided by 2
168
+ `low_freq_factor` (`float`, *optional*):
169
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
170
+ `high_freq_factor` (`float`, *optional*):
171
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
172
+
173
+ ```python
174
+ >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
175
+
176
+ >>> # Initializing a Qwen2VL style configuration
177
+ >>> configuration = Qwen2VLConfig()
178
+
179
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
180
+ >>> model = Qwen2VLForConditionalGeneration(configuration)
181
+
182
+ >>> # Accessing the model configuration
183
+ >>> configuration = model.config
184
+ ```"""
185
+
186
+ model_type = "qwen2_vl"
187
+ keys_to_ignore_at_inference = ["past_key_values"]
188
+
189
+ def __init__(
190
+ self,
191
+ vocab_size=152064,
192
+ hidden_size=8192,
193
+ intermediate_size=29568,
194
+ num_hidden_layers=80,
195
+ num_attention_heads=64,
196
+ num_key_value_heads=8,
197
+ hidden_act="silu",
198
+ max_position_embeddings=32768,
199
+ initializer_range=0.02,
200
+ rms_norm_eps=1e-05,
201
+ use_cache=True,
202
+ tie_word_embeddings=False,
203
+ rope_theta=1000000.0,
204
+ use_sliding_window=False,
205
+ sliding_window=4096,
206
+ max_window_layers=80,
207
+ attention_dropout=0.0,
208
+ vision_config=None,
209
+ rope_scaling=None,
210
+ **kwargs,
211
+ ):
212
+ if isinstance(vision_config, dict):
213
+ self.vision_config = Qwen2VLVisionConfig(**vision_config)
214
+ elif vision_config is None:
215
+ self.vision_config = Qwen2VLVisionConfig()
216
+
217
+ self.vocab_size = vocab_size
218
+ self.max_position_embeddings = max_position_embeddings
219
+ self.hidden_size = hidden_size
220
+ self.intermediate_size = intermediate_size
221
+ self.num_hidden_layers = num_hidden_layers
222
+ self.num_attention_heads = num_attention_heads
223
+ self.use_sliding_window = use_sliding_window
224
+ self.sliding_window = sliding_window
225
+ self.max_window_layers = max_window_layers
226
+
227
+ # for backward compatibility
228
+ if num_key_value_heads is None:
229
+ num_key_value_heads = num_attention_heads
230
+
231
+ self.num_key_value_heads = num_key_value_heads
232
+ self.hidden_act = hidden_act
233
+ self.initializer_range = initializer_range
234
+ self.rms_norm_eps = rms_norm_eps
235
+ self.use_cache = use_cache
236
+ self.rope_theta = rope_theta
237
+ self.attention_dropout = attention_dropout
238
+ self.rope_scaling = rope_scaling
239
+
240
+ # Validate the correctness of rotary position embeddings parameters
241
+ # BC: if there is a 'type' field, move it to 'rope_type'.
242
+ # and change type from 'mrope' to 'default'
243
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
244
+ if self.rope_scaling["type"] == "mrope":
245
+ self.rope_scaling["type"] = "default"
246
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
247
+ rope_config_validation(self)
248
+
249
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
image_processing_dualvitok.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ from transformers.utils import TensorType, is_vision_available, logging
4
+
5
+ from .image_processing_movqgan import MoVQImageProcessor
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class DualViTokImageProcessor(MoVQImageProcessor):
11
+ r"""
12
+ Constructs a DualViTok image processor that dynamically resizes images based on the original images.
13
+ This image processor is based on MoVQImageProcessor with spatial_factor of 16.
14
+
15
+ Args:
16
+ do_resize (`bool`, *optional*, defaults to `True`):
17
+ Whether to resize the image's (height, width) dimensions.
18
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
19
+ Resampling filter to use when resizing the image.
20
+ do_rescale (`bool`, *optional*, defaults to `True`):
21
+ Whether to rescale the image by the specified scale `rescale_factor`.
22
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
23
+ Scale factor to use if rescaling the image.
24
+ do_normalize (`bool`, *optional*, defaults to `True`):
25
+ Whether to normalize the image.
26
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
27
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
28
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
29
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
30
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
31
+ Whether to convert the image to RGB.
32
+ min_pixels (`int`, *optional*, defaults to `512 * 512`):
33
+ The min pixels of the image to resize the image.
34
+ max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
35
+ The max pixels of the image to resize the image.
36
+ spatial_factor (`int`, *optional*, defautls to 8):
37
+ The spatial downsample factor the image will be downsampled in feature extracting phase
38
+ """
39
+
40
+ model_input_names = ["pixel_values"]
41
+
42
+ def __init__(
43
+ self,
44
+ *args,
45
+ spatial_factor: int = 16,
46
+ **kwargs,
47
+ ) -> None:
48
+ super().__init__(*args, spatial_factor=spatial_factor, **kwargs)
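With `spatial_factor=16`, the processor snaps each image to dimensions divisible by 16, within the inherited `min_pixels=32 * 32` / `max_pixels=1024 * 1024` budget. A quick sanity-check sketch (assumes the module is importable locally):

```python
import numpy as np
from PIL import Image

processor = DualViTokImageProcessor()
image = Image.fromarray(np.zeros((500, 333, 3), dtype=np.uint8))
pixel_values = processor(image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # torch.Size([1, 3, 496, 336]): 500 -> 496, 333 -> 336
```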
image_processing_movqgan.py ADDED
@@ -0,0 +1,428 @@
"""Image processor class for MoVQ."""


import math
from typing import Dict, List, Optional, Union

import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    convert_to_rgb,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from transformers.utils import TensorType, is_vision_available, logging


logger = logging.get_logger(__name__)


if is_vision_available():
    from PIL import Image


def smart_resize(
    height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024
):
    """Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.

    """
    # if height < factor or width < factor:
    #     raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    # elif max(height, width) / min(height, width) > 5:
    #     raise ValueError(
    #         f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}"
    #     )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return max(h_bar, factor), max(w_bar, factor)


class MoVQImageProcessor(BaseImageProcessor):
    r"""
    Constructs a MoVQ image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Mean to use if normalizing the image. This is a float or a list of floats, one per channel.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Standard deviation to use if normalizing the image. This is a float or a list of floats, one per channel.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_pixels (`int`, *optional*, defaults to `32 * 32`):
            The minimum number of pixels the image is resized to.
        max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
            The maximum number of pixels the image is resized to.
        spatial_factor (`int`, *optional*, defaults to 8):
            The spatial downsampling factor applied to the image in the feature-extraction phase.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 32 * 32,
        max_pixels: int = 1024 * 1024,
        spatial_factor: int = 8,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        self.spatial_factor = spatial_factor

    def _preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        spatial_factor: Optional[int] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
                The spatial downsampling factor applied to the image in the feature-extraction phase.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
        """
        spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor

        images = make_list_of_images(images)
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=spatial_factor,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                )
                image = resize(
                    image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
                )

            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format)
            processed_images.append(image)

        image = np.array(processed_images)
        return image

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        spatial_factor: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
                The spatial downsampling factor applied to the image in the feature-extraction phase.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor

        images = make_list_of_images(images)
        if images is None or not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=self.size,
            resample=resample,
        )

        pixel_values = []
        for image in images:
            norm_image = self._preprocess(
                image,
                do_resize=do_resize,
                resample=resample,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                do_convert_rgb=do_convert_rgb,
                spatial_factor=spatial_factor,
                input_data_format=input_data_format,
                output_data_format=output_data_format,
            )
            pixel_values.extend(norm_image)
        pixel_values = np.array(pixel_values)
        data = {"pixel_values": pixel_values}

        return BatchFeature(data=data, tensor_type=return_tensors)

    def postprocess(
        self,
        images: ImageInput,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = "PIL.Image.Image",
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Postprocess an image or a batch of image tensors. Postprocess is the reverse of `preprocess`: the
        parameters should be the same as those used in `preprocess`.

        Args:
            images (`ImageInput`):
                Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        rescale_factor = 1 / rescale_factor

        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        image_mean, image_std = self.inverse_meanstd(image_mean, image_std)

        images = make_list_of_images(images)
        if isinstance(images[0], Image.Image):
            return images if len(images) > 1 else images[0]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        pixel_values = []
        for image in images:
            image = to_numpy_array(image)
            if do_normalize:
                image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
                image = image.clip(0, 255).astype(np.uint8)

            if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
                image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
                pixel_values.append(Image.fromarray(image))
            else:
                pixel_values.extend(image)

        data = {"pixel_values": pixel_values}
        return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None

        return BatchFeature(data=data, tensor_type=return_tensors)

    def inverse_meanstd(self, image_mean, image_std):
        image_mean = self.to_tuple(image_mean)
        image_std = self.to_tuple(image_std)

        rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std))
        rev_image_std = tuple(1 / s for s in image_std)

        return rev_image_mean, rev_image_std

    def to_tuple(self, value, dim=3):
        if isinstance(value, (int, float)):
            return (value,) * dim

        return tuple(value)
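To make the resizing rule concrete, a worked example with `smart_resize`'s own defaults (`factor=8`, `min_pixels=512*512`, `max_pixels=1024*1024`); note `preprocess` actually calls it with the instance's `min_pixels`/`max_pixels`, which default to `32*32`/`1024*1024`:

```python
# A 500x333 image has 166,500 pixels < 512*512 = 262,144, so it is scaled up by
# beta = sqrt(262144 / 166500) ≈ 1.2548 and each side is snapped up to a multiple of 8.
print(smart_resize(500, 333))  # (632, 424); 632 * 424 = 267,968 >= 262,144

# `postprocess` undoes normalization by reusing `normalize` with the parameters from
# `inverse_meanstd`: (x - (-m/s)) / (1/s) = s*x + m, the exact inverse of (x - m)/s.
```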
image_processing_qwen2vit.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Image processor class for Qwen2-VL."""
21
+
22
+ import math
23
+ from typing import Dict, List, Optional, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ from torch import nn
28
+
29
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
30
+ from transformers.image_transforms import (
31
+ convert_to_rgb,
32
+ resize,
33
+ to_channel_dimension_format,
34
+ )
35
+ from .image_utils import (
36
+ OPENAI_CLIP_MEAN,
37
+ OPENAI_CLIP_STD,
38
+ ChannelDimension,
39
+ ImageInput,
40
+ PILImageResampling,
41
+ VideoInput,
42
+ get_image_size,
43
+ infer_channel_dimension_format,
44
+ is_scaled_image,
45
+ is_valid_image,
46
+ make_list_of_images,
47
+ to_numpy_array,
48
+ valid_images,
49
+ validate_preprocess_arguments,
50
+ )
51
+ from transformers.utils import TensorType, is_vision_available, logging
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ if is_vision_available():
58
+ from PIL import Image
59
+
60
+
61
+ def make_batched_images(images) -> List[List[ImageInput]]:
62
+ """
63
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
64
+
65
+ Args:
66
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
67
+ The input image.
68
+
69
+ Returns:
70
+ list: A list of images.
71
+ """
72
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
73
+ return [img for img_list in images for img in img_list]
74
+
75
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
76
+ return images
77
+
78
+ elif is_valid_image(images):
79
+ return [images]
80
+
81
+ raise ValueError(f"Could not make batched images from {images}")
82
+
83
+
84
+ # Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
85
+ def make_batched_videos(videos) -> List[VideoInput]:
86
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
87
+ return videos
88
+
89
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
90
+ if isinstance(videos[0], Image.Image):
91
+ return [videos]
92
+ elif len(videos[0].shape) == 4:
93
+ return [list(video) for video in videos]
94
+
95
+ elif is_valid_image(videos) and len(videos.shape) == 4:
96
+ return [list(videos)]
97
+
98
+ raise ValueError(f"Could not make batched video from {videos}")
99
+
100
+
101
+ def smart_resize(
102
+ height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
103
+ ):
104
+ """Rescales the image so that the following conditions are met:
105
+
106
+ 1. Both dimensions (height and width) are divisible by 'factor'.
107
+
108
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
109
+
110
+ 3. The aspect ratio of the image is maintained as closely as possible.
111
+
112
+ """
113
+ if height < factor or width < factor:
115
+ if height < width:
116
+ h_bar = factor
117
+ w_bar = round(width / height * factor)
118
+ else:
119
+ h_bar = round(height / width * factor)
120
+ w_bar = factor
122
+ height, width = h_bar, w_bar
124
+ elif max(height, width) / min(height, width) > 200:
125
+ raise ValueError(
126
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
127
+ )
128
+ h_bar = round(height / factor) * factor
129
+ w_bar = round(width / factor) * factor
130
+ if h_bar * w_bar > max_pixels:
131
+ beta = math.sqrt((height * width) / max_pixels)
132
+ h_bar = math.floor(height / beta / factor) * factor
133
+ w_bar = math.floor(width / beta / factor) * factor
134
+ elif h_bar * w_bar < min_pixels:
135
+ beta = math.sqrt(min_pixels / (height * width))
136
+ h_bar = math.ceil(height * beta / factor) * factor
137
+ w_bar = math.ceil(width * beta / factor) * factor
138
+ return h_bar, w_bar
139
+
140
+
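
For a concrete feel of the contract documented in the `smart_resize` docstring, a small sketch (assuming `smart_resize` as defined above is in scope; the input shape is illustrative):

```python
# smart_resize with its defaults (factor=28, CLIP-style pixel bounds).
h_bar, w_bar = smart_resize(1000, 500)
assert h_bar % 28 == 0 and w_bar % 28 == 0              # condition 1: divisible by factor
assert 56 * 56 <= h_bar * w_bar <= 14 * 14 * 4 * 1280   # condition 2: pixel count in range
print(h_bar, w_bar)  # 1008 504 -- the 2:1 aspect ratio is preserved (condition 3)
```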
141
+ class Qwen2VLImageProcessor(BaseImageProcessor):
142
+ r"""
143
+ Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
144
+
145
+ Args:
146
+ do_resize (`bool`, *optional*, defaults to `True`):
147
+ Whether to resize the image's (height, width) dimensions.
148
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
149
+ Resampling filter to use when resizing the image.
150
+ do_rescale (`bool`, *optional*, defaults to `True`):
151
+ Whether to rescale the image by the specified scale `rescale_factor`.
152
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
153
+ Scale factor to use if rescaling the image.
154
+ do_normalize (`bool`, *optional*, defaults to `True`):
155
+ Whether to normalize the image.
156
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
157
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
158
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
159
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
160
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
161
+ Whether to convert the image to RGB.
162
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
163
+ The minimum number of pixels allowed in the resized image.
164
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
165
+ The maximum number of pixels allowed in the resized image.
166
+ patch_size (`int`, *optional*, defaults to 14):
167
+ The spatial patch size of the vision encoder.
168
+ temporal_patch_size (`int`, *optional*, defaults to 2):
169
+ The temporal patch size of the vision encoder.
170
+ merge_size (`int`, *optional*, defaults to 2):
171
+ The spatial merge factor applied when mapping vision-encoder patches to LLM tokens.
172
+ """
173
+
174
+ model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
175
+
176
+ def __init__(
177
+ self,
178
+ do_resize: bool = True,
179
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
180
+ do_rescale: bool = True,
181
+ rescale_factor: Union[int, float] = 1 / 255,
182
+ do_normalize: bool = True,
183
+ image_mean: Optional[Union[float, List[float]]] = None,
184
+ image_std: Optional[Union[float, List[float]]] = None,
185
+ do_convert_rgb: bool = True,
186
+ min_pixels: int = 56 * 56,
187
+ max_pixels: int = 28 * 28 * 1280,
188
+ patch_size: int = 14,
189
+ temporal_patch_size: int = 2,
190
+ merge_size: int = 2,
191
+ shifted_patch_tokenize: bool = False,
192
+ **kwargs,
193
+ ) -> None:
194
+ super().__init__(**kwargs)
195
+ self.do_resize = do_resize
196
+ self.resample = resample
197
+ self.do_rescale = do_rescale
198
+ self.rescale_factor = rescale_factor
199
+ self.do_normalize = do_normalize
200
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
201
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
202
+ self.min_pixels = min_pixels
203
+ self.max_pixels = max_pixels
204
+ self.patch_size = patch_size
205
+ self.temporal_patch_size = temporal_patch_size
206
+ self.merge_size = merge_size
207
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
208
+ self.do_convert_rgb = do_convert_rgb
209
+ self.shifted_patch_tokenize = shifted_patch_tokenize
210
+
211
+ def _preprocess(
212
+ self,
213
+ images: Union[ImageInput, VideoInput],
214
+ do_resize: bool = None,
215
+ resample: PILImageResampling = None,
216
+ do_rescale: bool = None,
217
+ rescale_factor: float = None,
218
+ do_normalize: bool = None,
219
+ image_mean: Optional[Union[float, List[float]]] = None,
220
+ image_std: Optional[Union[float, List[float]]] = None,
221
+ do_convert_rgb: bool = None,
222
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
223
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
224
+ ):
225
+ """
226
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
227
+
228
+ Args:
229
+ images (`ImageInput`):
230
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
233
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
234
+ Whether to resize the image.
235
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
236
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
237
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
238
+ Whether to rescale the image.
239
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
240
+ Scale factor to use if rescaling the image.
241
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
242
+ Whether to normalize the image.
243
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
244
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
245
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
246
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
247
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
248
+ Whether to convert the image to RGB.
249
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
250
+ The channel dimension format for the output image. Can be one of:
251
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
252
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
253
+ - Unset: Use the channel dimension format of the input image.
254
+ input_data_format (`ChannelDimension` or `str`, *optional*):
255
+ The channel dimension format for the input image. Can be one of:
256
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
257
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
258
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
259
+ """
264
+ images = make_list_of_images(images)
265
+
266
+ if do_convert_rgb:
267
+ images = [convert_to_rgb(image) for image in images]
268
+
269
+ # All transformations expect numpy arrays.
270
+ images = [to_numpy_array(image) for image in images]
271
+
272
+ if is_scaled_image(images[0]) and do_rescale:
273
+ logger.warning_once(
274
+ "It looks like you are trying to rescale already rescaled images. If the input"
275
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
276
+ )
277
+ if input_data_format is None:
278
+ # We assume that all images have the same channel dimension format.
279
+ input_data_format = infer_channel_dimension_format(images[0])
280
+
281
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
282
+ resized_height, resized_width = height, width
283
+ processed_images = []
284
+ for image in images:
285
+ if do_resize:
286
+ resized_height, resized_width = smart_resize(
287
+ height,
288
+ width,
289
+ factor=self.patch_size * self.merge_size,
290
+ min_pixels=self.min_pixels,
291
+ max_pixels=self.max_pixels,
292
+ )
293
+ image = resize(
294
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
295
+ )
296
+
297
+ if do_rescale:
298
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
299
+
300
+ if do_normalize:
301
+ image = self.normalize(
302
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
303
+ )
304
+
305
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
306
+ processed_images.append(image)
307
+
308
+ patches = np.array(processed_images)
309
+ if data_format == ChannelDimension.LAST:
310
+ patches = patches.transpose(0, 3, 1, 2)
311
+ if patches.shape[0] == 1:
312
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
313
+ channel = patches.shape[1]
314
+ grid_t = patches.shape[0] // self.temporal_patch_size
315
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
316
+ patches = patches.reshape(
317
+ grid_t,
318
+ self.temporal_patch_size,
319
+ channel,
320
+ grid_h // self.merge_size,
321
+ self.merge_size,
322
+ self.patch_size,
323
+ grid_w // self.merge_size,
324
+ self.merge_size,
325
+ self.patch_size,
326
+ )
327
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
328
+ flatten_patches = patches.reshape(
329
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
330
+ )
331
+
332
+ return flatten_patches, (grid_t, grid_h, grid_w)
333
+
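
The reshape/transpose pair at the end of `_preprocess` is easiest to follow with concrete shapes. A standalone sketch with illustrative sizes (224x224 RGB, `patch_size=14`, `merge_size=2`, `temporal_patch_size=2`):

```python
import numpy as np

frames = np.zeros((1, 3, 224, 224), dtype=np.float32)   # one image, channels first
frames = np.tile(frames, (2, 1, 1, 1))                  # repeat along time to fill temporal_patch_size=2
grid_t, grid_h, grid_w = 1, 224 // 14, 224 // 14        # (1, 16, 16)
patches = frames.reshape(grid_t, 2, 3, grid_h // 2, 2, 14, grid_w // 2, 2, 14)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)  # group each 2x2 merge window together
flat = patches.reshape(grid_t * grid_h * grid_w, 3 * 2 * 14 * 14)
print(flat.shape)  # (256, 1176): 256 patch tokens, each a 1176-dim vector
```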
334
+ def preprocess(
335
+ self,
336
+ images: ImageInput,
337
+ videos: VideoInput = None,
338
+ do_resize: bool = None,
339
+ size: Dict[str, int] = None,
340
+ resample: PILImageResampling = None,
341
+ do_rescale: bool = None,
342
+ rescale_factor: float = None,
343
+ do_normalize: bool = None,
344
+ image_mean: Optional[Union[float, List[float]]] = None,
345
+ image_std: Optional[Union[float, List[float]]] = None,
346
+ do_convert_rgb: bool = None,
347
+ return_tensors: Optional[Union[str, TensorType]] = None,
348
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
349
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
350
+ ):
351
+ """
352
+ Args:
353
+ images (`ImageInput`):
354
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
355
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
356
+ videos (`VideoInput`):
357
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
358
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
359
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
360
+ Whether to resize the image.
361
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
362
+ Pixel-count bounds used for resizing, as `{"min_pixels": ..., "max_pixels": ...}`. The image is resized
363
+ so that its total pixel count falls within this range while preserving the aspect ratio.
364
+ resample (`int`, *optional*, defaults to `self.resample`):
365
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
366
+ has an effect if `do_resize` is set to `True`.
367
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
368
+ Whether to rescale the image.
369
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
370
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
371
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
372
+ Whether to normalize the image.
373
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
374
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
375
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
376
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
377
+ `True`.
378
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
379
+ Whether to convert the image to RGB.
380
+ return_tensors (`str` or `TensorType`, *optional*):
381
+ The type of tensors to return. Can be one of:
382
+ - Unset: Return a list of `np.ndarray`.
383
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
384
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
385
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
386
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
387
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
388
+ The channel dimension format for the output image. Can be one of:
389
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
390
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
391
+ - Unset: Use the channel dimension format of the input image.
392
+ input_data_format (`ChannelDimension` or `str`, *optional*):
393
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
394
+ from the input image. Can be one of:
395
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
396
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
397
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
398
+
399
+ """
400
+ do_resize = do_resize if do_resize is not None else self.do_resize
401
+ size = size if size is not None else self.size
402
+ resample = resample if resample is not None else self.resample
403
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
404
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
405
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
406
+ image_mean = image_mean if image_mean is not None else self.image_mean
407
+ image_std = image_std if image_std is not None else self.image_std
408
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
409
+
410
+ if images is not None:
411
+ images = make_batched_images(images)
412
+ if videos is not None:
413
+ videos = make_batched_videos(videos)
414
+
415
+ if images is not None and not valid_images(images):
416
+ raise ValueError(
417
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
418
+ "torch.Tensor, tf.Tensor or jax.ndarray."
419
+ )
420
+
421
+ validate_preprocess_arguments(
422
+ rescale_factor=rescale_factor,
423
+ do_normalize=do_normalize,
424
+ image_mean=image_mean,
425
+ image_std=image_std,
426
+ do_resize=do_resize,
427
+ size=size,
428
+ resample=resample,
429
+ )
430
+
431
+ if images is not None:
432
+ pixel_values, vision_grid_thws = [], []
433
+ for image in images:
434
+ patches, image_grid_thw = self._preprocess(
435
+ image,
436
+ do_resize=do_resize,
437
+ resample=resample,
438
+ do_rescale=do_rescale,
439
+ rescale_factor=rescale_factor,
440
+ do_normalize=do_normalize,
441
+ image_mean=image_mean,
442
+ image_std=image_std,
443
+ data_format=data_format,
444
+ do_convert_rgb=do_convert_rgb,
445
+ input_data_format=input_data_format,
446
+ )
447
+ pixel_values.extend(patches)
448
+ vision_grid_thws.append(image_grid_thw)
449
+ pixel_values = np.array(pixel_values)
450
+ vision_grid_thws = np.array(vision_grid_thws)
451
+ data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
452
+
453
+ if videos is not None:
454
+ pixel_values, vision_grid_thws = [], []
455
+ for images in videos:
456
+ patches, video_grid_thw = self._preprocess(
457
+ images,
458
+ do_resize=do_resize,
459
+ resample=resample,
460
+ do_rescale=do_rescale,
461
+ rescale_factor=rescale_factor,
462
+ do_normalize=do_normalize,
463
+ image_mean=image_mean,
464
+ image_std=image_std,
465
+ data_format=data_format,
466
+ do_convert_rgb=do_convert_rgb,
467
+ input_data_format=input_data_format,
468
+ )
469
+ pixel_values.extend(patches)
470
+ vision_grid_thws.append(video_grid_thw)
471
+ pixel_values = np.array(pixel_values)
472
+ vision_grid_thws = np.array(vision_grid_thws)
473
+ data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
474
+
475
+ return BatchFeature(data=data, tensor_type=return_tensors)
476
+
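
A hypothetical end-to-end call of the processor defined above; the random input image and the printed values are illustrative, not taken from this repository:

```python
import numpy as np
from PIL import Image

processor = Qwen2VLImageProcessor()  # defaults: patch_size=14, merge_size=2
img = Image.fromarray(np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8))
out = processor.preprocess(images=img, return_tensors="np")
print(out["pixel_values"].shape)  # (grid_h * grid_w, 1176) flattened patches
print(out["image_grid_thw"])      # e.g. [[1, 22, 36]] for a 300x500 input
```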
image_utils.py ADDED
@@ -0,0 +1,812 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import base64
17
+ import os
18
+ from io import BytesIO
19
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import requests
23
+ from packaging import version
24
+
25
+
26
+ from transformers.utils import (
27
+ ExplicitEnum,
28
+ is_jax_tensor,
29
+ is_numpy_array,
30
+ is_tf_tensor,
31
+ is_torch_available,
32
+ is_torch_tensor,
33
+ is_torchvision_available,
34
+ is_vision_available,
35
+ logging,
36
+ requires_backends,
37
+ to_numpy,
38
+ )
39
+ from transformers.utils.constants import ( # noqa: F401
40
+ IMAGENET_DEFAULT_MEAN,
41
+ IMAGENET_DEFAULT_STD,
42
+ IMAGENET_STANDARD_MEAN,
43
+ IMAGENET_STANDARD_STD,
44
+ OPENAI_CLIP_MEAN,
45
+ OPENAI_CLIP_STD,
46
+ )
47
+
48
+
49
+ if is_vision_available():
50
+ import PIL.Image
51
+ import PIL.ImageOps
52
+
53
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
54
+ PILImageResampling = PIL.Image.Resampling
55
+ else:
56
+ PILImageResampling = PIL.Image
57
+
58
+ if is_torchvision_available():
59
+ from torchvision.transforms import InterpolationMode
60
+
61
+ pil_torch_interpolation_mapping = {
62
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
63
+ PILImageResampling.BOX: InterpolationMode.BOX,
64
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
65
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
66
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
67
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
68
+ }
69
+
70
+
71
+ if TYPE_CHECKING:
72
+ if is_torch_available():
73
+ import torch
74
+
75
+
76
+ logger = logging.get_logger(__name__)
77
+
78
+
79
+ ImageInput = Union[
80
+ "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
81
+ ] # noqa
82
+
83
+
84
+ VideoInput = Union[
85
+ List["PIL.Image.Image"],
86
+ "np.ndarray",
87
+ "torch.Tensor",
88
+ List["np.ndarray"],
89
+ List["torch.Tensor"],
90
+ List[List["PIL.Image.Image"]],
91
+ List[List["np.ndarrray"]],
92
+ List[List["torch.Tensor"]],
93
+ ] # noqa
94
+
95
+
96
+ class ChannelDimension(ExplicitEnum):
97
+ FIRST = "channels_first"
98
+ LAST = "channels_last"
99
+
100
+
101
+ class AnnotationFormat(ExplicitEnum):
102
+ COCO_DETECTION = "coco_detection"
103
+ COCO_PANOPTIC = "coco_panoptic"
104
+
105
+
106
+ class AnnotionFormat(ExplicitEnum):
107
+ COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
108
+ COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
109
+
110
+
111
+ AnnotationType = Dict[str, Union[int, str, List[Dict]]]
112
+
113
+
114
+ def is_pil_image(img):
115
+ return is_vision_available() and isinstance(img, PIL.Image.Image)
116
+
117
+
118
+ class ImageType(ExplicitEnum):
119
+ PIL = "pillow"
120
+ TORCH = "torch"
121
+ NUMPY = "numpy"
122
+ TENSORFLOW = "tensorflow"
123
+ JAX = "jax"
124
+
125
+
126
+ def get_image_type(image):
127
+ if is_pil_image(image):
128
+ return ImageType.PIL
129
+ if is_torch_tensor(image):
130
+ return ImageType.TORCH
131
+ if is_numpy_array(image):
132
+ return ImageType.NUMPY
133
+ if is_tf_tensor(image):
134
+ return ImageType.TENSORFLOW
135
+ if is_jax_tensor(image):
136
+ return ImageType.JAX
137
+ raise ValueError(f"Unrecognised image type {type(image)}")
138
+
139
+
140
+ def is_valid_image(img):
141
+ return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
142
+
143
+
144
+ def valid_images(imgs):
145
+ # If we have a list of images, make sure every image is valid
146
+ if isinstance(imgs, (list, tuple)):
147
+ for img in imgs:
148
+ if not valid_images(img):
149
+ return False
150
+ # If not a list or tuple, we have been given a single image or batched tensor of images
151
+ elif not is_valid_image(imgs):
152
+ return False
153
+ return True
154
+
155
+
156
+ def is_batched(img):
157
+ if isinstance(img, (list, tuple)):
158
+ return is_valid_image(img[0])
159
+ return False
160
+
161
+
162
+ def is_scaled_image(image: np.ndarray) -> bool:
163
+ """
164
+ Checks to see whether the pixel values have already been rescaled to [0, 1].
165
+ """
166
+ if image.dtype == np.uint8:
167
+ return False
168
+
169
+ # It's possible the image has pixel values in [0, 255] but is of floating type
170
+ return np.min(image) >= 0 and np.max(image) <= 1
171
+
172
+
173
+ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
174
+ """
175
+ Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
176
+ If the input is a batch of images, it is converted to a list of images.
177
+
178
+ Args:
179
+ images (`ImageInput`):
180
+ Image or images to turn into a list of images.
181
+ expected_ndims (`int`, *optional*, defaults to 3):
182
+ Expected number of dimensions for a single input image. If the input image has a different number of
183
+ dimensions, an error is raised.
184
+ """
185
+ if is_batched(images):
186
+ return images
187
+
188
+ # If the input is a single image, we create a list of length 1
189
+ if isinstance(images, PIL.Image.Image):
190
+ # PIL images are never batched
191
+ return [images]
192
+
193
+ if is_valid_image(images):
194
+ if images.ndim == expected_ndims + 1:
195
+ # Batch of images
196
+ images = list(images)
197
+ elif images.ndim == expected_ndims:
198
+ # Single image
199
+ images = [images]
200
+ else:
201
+ raise ValueError(
202
+ f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
203
+ f" {images.ndim} dimensions."
204
+ )
205
+ return images
206
+ raise ValueError(
207
+ "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
208
+ f"jax.ndarray, but got {type(images)}."
209
+ )
210
+
211
+
212
+ def to_numpy_array(img) -> np.ndarray:
213
+ if not is_valid_image(img):
214
+ raise ValueError(f"Invalid image type: {type(img)}")
215
+
216
+ if is_vision_available() and isinstance(img, PIL.Image.Image):
217
+ return np.array(img)
218
+ return to_numpy(img)
219
+
220
+
221
+ def infer_channel_dimension_format(
222
+ image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
223
+ ) -> ChannelDimension:
224
+ """
225
+ Infers the channel dimension format of `image`.
226
+
227
+ Args:
228
+ image (`np.ndarray`):
229
+ The image to infer the channel dimension of.
230
+ num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
231
+ The number of channels of the image.
232
+
233
+ Returns:
234
+ The channel dimension of the image.
235
+ """
236
+ num_channels = num_channels if num_channels is not None else (1, 3)
237
+ num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
238
+
239
+ if image.ndim == 3:
240
+ first_dim, last_dim = 0, 2
241
+ elif image.ndim == 4:
242
+ first_dim, last_dim = 1, 3
243
+ else:
244
+ raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
245
+
246
+ if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
247
+ logger.warning(
248
+ f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
249
+ )
250
+ return ChannelDimension.FIRST
251
+ elif image.shape[first_dim] in num_channels:
252
+ return ChannelDimension.FIRST
253
+ elif image.shape[last_dim] in num_channels:
254
+ return ChannelDimension.LAST
255
+ raise ValueError("Unable to infer channel dimension format")
256
+
257
+
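
A small sketch of how the inference above resolves common shapes (assuming the helpers in this file are in scope):

```python
import numpy as np

chw = np.zeros((3, 224, 224))
hwc = np.zeros((224, 224, 3))
print(infer_channel_dimension_format(chw))  # ChannelDimension.FIRST
print(infer_channel_dimension_format(hwc))  # ChannelDimension.LAST
# An ambiguous shape such as (3, 224, 3) logs a warning and defaults to FIRST.
```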
258
+ def get_channel_dimension_axis(
259
+ image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
260
+ ) -> int:
261
+ """
262
+ Returns the channel dimension axis of the image.
263
+
264
+ Args:
265
+ image (`np.ndarray`):
266
+ The image to get the channel dimension axis of.
267
+ input_data_format (`ChannelDimension` or `str`, *optional*):
268
+ The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
269
+
270
+ Returns:
271
+ The channel dimension axis of the image.
272
+ """
273
+ if input_data_format is None:
274
+ input_data_format = infer_channel_dimension_format(image)
275
+ if input_data_format == ChannelDimension.FIRST:
276
+ return image.ndim - 3
277
+ elif input_data_format == ChannelDimension.LAST:
278
+ return image.ndim - 1
279
+ raise ValueError(f"Unsupported data format: {input_data_format}")
280
+
281
+
282
+ def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
283
+ """
284
+ Returns the (height, width) dimensions of the image.
285
+
286
+ Args:
287
+ image (`np.ndarray`):
288
+ The image to get the dimensions of.
289
+ channel_dim (`ChannelDimension`, *optional*):
290
+ Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
291
+
292
+ Returns:
293
+ A tuple of the image's height and width.
294
+ """
295
+ if channel_dim is None:
296
+ channel_dim = infer_channel_dimension_format(image)
297
+
298
+ if channel_dim == ChannelDimension.FIRST:
299
+ return image.shape[-2], image.shape[-1]
300
+ elif channel_dim == ChannelDimension.LAST:
301
+ return image.shape[-3], image.shape[-2]
302
+ else:
303
+ raise ValueError(f"Unsupported data format: {channel_dim}")
304
+
305
+
306
+ def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
307
+ if (
308
+ isinstance(annotation, dict)
309
+ and "image_id" in annotation
310
+ and "annotations" in annotation
311
+ and isinstance(annotation["annotations"], (list, tuple))
312
+ and (
313
+ # an image can have no annotations
314
+ len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
315
+ )
316
+ ):
317
+ return True
318
+ return False
319
+
320
+
321
+ def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
322
+ if (
323
+ isinstance(annotation, dict)
324
+ and "image_id" in annotation
325
+ and "segments_info" in annotation
326
+ and "file_name" in annotation
327
+ and isinstance(annotation["segments_info"], (list, tuple))
328
+ and (
329
+ # an image can have no segments
330
+ len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
331
+ )
332
+ ):
333
+ return True
334
+ return False
335
+
336
+
337
+ def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
338
+ return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
339
+
340
+
341
+ def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
342
+ return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
343
+
344
+
345
+ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
346
+ """
347
+ Loads `image` to a PIL Image.
348
+
349
+ Args:
350
+ image (`str` or `PIL.Image.Image`):
351
+ The image to convert to the PIL Image format.
352
+ timeout (`float`, *optional*):
353
+ The timeout value in seconds for the URL request.
354
+
355
+ Returns:
356
+ `PIL.Image.Image`: A PIL Image.
357
+ """
358
+ requires_backends(load_image, ["vision"])
359
+ if isinstance(image, str):
360
+ if image.startswith("http://") or image.startswith("https://"):
361
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
362
+ # like http_huggingface_co.png
363
+ image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
364
+ elif os.path.isfile(image):
365
+ image = PIL.Image.open(image)
366
+ else:
367
+ if image.startswith("data:image/"):
368
+ image = image.split(",")[1]
369
+
370
+ # Try to load as base64
371
+ try:
372
+ b64 = base64.decodebytes(image.encode())
373
+ image = PIL.Image.open(BytesIO(b64))
374
+ except Exception as e:
375
+ raise ValueError(
376
+ f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
377
+ )
378
+ elif isinstance(image, PIL.Image.Image):
379
+ pass  # already a PIL image
380
+ else:
381
+ raise TypeError(
382
+ "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
383
+ )
384
+ image = PIL.ImageOps.exif_transpose(image)
385
+ image = image.convert("RGB")
386
+ return image
387
+
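
The three input forms `load_image` accepts, sketched below; the URL and file paths are placeholders, not real assets:

```python
import base64

img_from_url = load_image("https://example.com/cat.png", timeout=5.0)  # fetched over HTTP(S)
img_from_path = load_image("/tmp/cat.png")                             # read from disk
with open("/tmp/cat.png", "rb") as f:
    payload = base64.b64encode(f.read()).decode()
img_from_b64 = load_image("data:image/png;base64," + payload)          # decoded from base64
# All three paths return an EXIF-transposed RGB PIL.Image.Image.
```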
388
+
389
+ def validate_preprocess_arguments(
390
+ do_rescale: Optional[bool] = None,
391
+ rescale_factor: Optional[float] = None,
392
+ do_normalize: Optional[bool] = None,
393
+ image_mean: Optional[Union[float, List[float]]] = None,
394
+ image_std: Optional[Union[float, List[float]]] = None,
395
+ do_pad: Optional[bool] = None,
396
+ size_divisibility: Optional[int] = None,
397
+ do_center_crop: Optional[bool] = None,
398
+ crop_size: Optional[Dict[str, int]] = None,
399
+ do_resize: Optional[bool] = None,
400
+ size: Optional[Dict[str, int]] = None,
401
+ resample: Optional["PILImageResampling"] = None,
402
+ ):
403
+ """
404
+ Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
405
+ Raises a `ValueError` if any incompatibility between the arguments is found.
406
+ Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
407
+ sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
408
+ existing arguments when possible.
409
+
410
+ """
411
+ if do_rescale and rescale_factor is None:
412
+ raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
413
+
414
+ if do_pad and size_divisibility is None:
415
+ # Here, size_divisor might be passed as the value of size
416
+ raise ValueError(
417
+ "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
418
+ )
419
+
420
+ if do_normalize and (image_mean is None or image_std is None):
421
+ raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
422
+
423
+ if do_center_crop and crop_size is None:
424
+ raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
425
+
426
+ if do_resize and (size is None or resample is None):
427
+ raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
428
+
429
+
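
A sketch of the validator above rejecting an inconsistent combination (assuming the function is in scope):

```python
try:
    validate_preprocess_arguments(do_rescale=True, rescale_factor=None)
except ValueError as err:
    print(err)  # `rescale_factor` must be specified if `do_rescale` is `True`.
```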
430
+ # In the future we can add a TF implementation here when we have TF models.
431
+ class ImageFeatureExtractionMixin:
432
+ """
433
+ Mixin that contain utilities for preparing image features.
434
+ """
435
+
436
+ def _ensure_format_supported(self, image):
437
+ if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
438
+ raise ValueError(
439
+ f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
440
+ "`torch.Tensor` are."
441
+ )
442
+
443
+ def to_pil_image(self, image, rescale=None):
444
+ """
445
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
446
+ needed.
447
+
448
+ Args:
449
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
450
+ The image to convert to the PIL Image format.
451
+ rescale (`bool`, *optional*):
452
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
453
+ default to `True` if the image type is a floating type, `False` otherwise.
454
+ """
455
+ self._ensure_format_supported(image)
456
+
457
+ if is_torch_tensor(image):
458
+ image = image.numpy()
459
+
460
+ if isinstance(image, np.ndarray):
461
+ if rescale is None:
462
+ # rescale defaults to True when the array is of floating type.
463
+ rescale = isinstance(image.flat[0], np.floating)
464
+ # If the channel has been moved to the first dim, we put it back at the end.
465
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
466
+ image = image.transpose(1, 2, 0)
467
+ if rescale:
468
+ image = image * 255
469
+ image = image.astype(np.uint8)
470
+ return PIL.Image.fromarray(image)
471
+ return image
472
+
473
+ def convert_rgb(self, image):
474
+ """
475
+ Converts `PIL.Image.Image` to RGB format.
476
+
477
+ Args:
478
+ image (`PIL.Image.Image`):
479
+ The image to convert.
480
+ """
481
+ self._ensure_format_supported(image)
482
+ if not isinstance(image, PIL.Image.Image):
483
+ return image
484
+
485
+ return image.convert("RGB")
486
+
487
+ def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
488
+ """
489
+ Rescale a numpy image by scale amount
490
+ """
491
+ self._ensure_format_supported(image)
492
+ return image * scale
493
+
494
+ def to_numpy_array(self, image, rescale=None, channel_first=True):
495
+ """
496
+ Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
497
+ dimension.
498
+
499
+ Args:
500
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
501
+ The image to convert to a NumPy array.
502
+ rescale (`bool`, *optional*):
503
+ Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
504
+ default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
505
+ channel_first (`bool`, *optional*, defaults to `True`):
506
+ Whether or not to permute the dimensions of the image to put the channel dimension first.
507
+ """
508
+ self._ensure_format_supported(image)
509
+
510
+ if isinstance(image, PIL.Image.Image):
511
+ image = np.array(image)
512
+
513
+ if is_torch_tensor(image):
514
+ image = image.numpy()
515
+
516
+ rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
517
+
518
+ if rescale:
519
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
520
+
521
+ if channel_first and image.ndim == 3:
522
+ image = image.transpose(2, 0, 1)
523
+
524
+ return image
525
+
526
+ def expand_dims(self, image):
527
+ """
528
+ Expands 2-dimensional `image` to 3 dimensions.
529
+
530
+ Args:
531
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
532
+ The image to expand.
533
+ """
534
+ self._ensure_format_supported(image)
535
+
536
+ # Do nothing if PIL image
537
+ if isinstance(image, PIL.Image.Image):
538
+ return image
539
+
540
+ if is_torch_tensor(image):
541
+ image = image.unsqueeze(0)
542
+ else:
543
+ image = np.expand_dims(image, axis=0)
544
+ return image
545
+
546
+ def normalize(self, image, mean, std, rescale=False):
547
+ """
548
+ Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
549
+ if it's a PIL Image.
550
+
551
+ Args:
552
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
553
+ The image to normalize.
554
+ mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
555
+ The mean (per channel) to use for normalization.
556
+ std (`List[float]` or `np.ndarray` or `torch.Tensor`):
557
+ The standard deviation (per channel) to use for normalization.
558
+ rescale (`bool`, *optional*, defaults to `False`):
559
+ Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
560
+ happen automatically.
561
+ """
562
+ self._ensure_format_supported(image)
563
+
564
+ if isinstance(image, PIL.Image.Image):
565
+ image = self.to_numpy_array(image, rescale=True)
566
+ # If the input image is a PIL image, it automatically gets rescaled. If it's another
567
+ # type it may need rescaling.
568
+ elif rescale:
569
+ if isinstance(image, np.ndarray):
570
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
571
+ elif is_torch_tensor(image):
572
+ image = self.rescale(image.float(), 1 / 255.0)
573
+
574
+ if isinstance(image, np.ndarray):
575
+ if not isinstance(mean, np.ndarray):
576
+ mean = np.array(mean).astype(image.dtype)
577
+ if not isinstance(std, np.ndarray):
578
+ std = np.array(std).astype(image.dtype)
579
+ elif is_torch_tensor(image):
580
+ import torch
581
+
582
+ if not isinstance(mean, torch.Tensor):
583
+ if isinstance(mean, np.ndarray):
584
+ mean = torch.from_numpy(mean)
585
+ else:
586
+ mean = torch.tensor(mean)
587
+ if not isinstance(std, torch.Tensor):
588
+ if isinstance(std, np.ndarray):
589
+ std = torch.from_numpy(std)
590
+ else:
591
+ std = torch.tensor(std)
592
+
593
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
594
+ return (image - mean[:, None, None]) / std[:, None, None]
595
+ else:
596
+ return (image - mean) / std
597
+
598
+ def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
599
+ """
600
+ Resizes `image`. Enforces conversion of input to PIL.Image.
601
+
602
+ Args:
603
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
604
+ The image to resize.
605
+ size (`int` or `Tuple[int, int]`):
606
+ The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
607
+ matched to this.
608
+
609
+ If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
610
+ `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
611
+ this number, i.e., if height > width, then the image will be rescaled to (size * height / width, size).
612
+ resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
613
+ The filter to use for resampling.
614
+ default_to_square (`bool`, *optional*, defaults to `True`):
615
+ How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
616
+ square (`size`,`size`). If set to `False`, will replicate
617
+ [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
618
+ with support for resizing only the smallest edge and providing an optional `max_size`.
619
+ max_size (`int`, *optional*, defaults to `None`):
620
+ The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
621
+ greater than `max_size` after being resized according to `size`, then the image is resized again so
622
+ that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
623
+ edge may be shorter than `size`. Only used if `default_to_square` is `False`.
624
+
625
+ Returns:
626
+ image: A resized `PIL.Image.Image`.
627
+ """
628
+ resample = resample if resample is not None else PILImageResampling.BILINEAR
629
+
630
+ self._ensure_format_supported(image)
631
+
632
+ if not isinstance(image, PIL.Image.Image):
633
+ image = self.to_pil_image(image)
634
+
635
+ if isinstance(size, list):
636
+ size = tuple(size)
637
+
638
+ if isinstance(size, int) or len(size) == 1:
639
+ if default_to_square:
640
+ size = (size, size) if isinstance(size, int) else (size[0], size[0])
641
+ else:
642
+ width, height = image.size
643
+ # specified size only for the smallest edge
644
+ short, long = (width, height) if width <= height else (height, width)
645
+ requested_new_short = size if isinstance(size, int) else size[0]
646
+
647
+ if short == requested_new_short:
648
+ return image
649
+
650
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
651
+
652
+ if max_size is not None:
653
+ if max_size <= requested_new_short:
654
+ raise ValueError(
655
+ f"max_size = {max_size} must be strictly greater than the requested "
656
+ f"size for the smaller edge size = {size}"
657
+ )
658
+ if new_long > max_size:
659
+ new_short, new_long = int(max_size * new_short / new_long), max_size
660
+
661
+ size = (new_short, new_long) if width <= height else (new_long, new_short)
662
+
663
+ return image.resize(size, resample=resample)
664
+
665
+ def center_crop(self, image, size):
666
+ """
667
+ Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
668
+ size given, it will be padded (so the returned result has the size asked).
669
+
670
+ Args:
671
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
672
+ The image to resize.
673
+ size (`int` or `Tuple[int, int]`):
674
+ The size to which crop the image.
675
+
676
+ Returns:
677
+ new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
678
+ height, width).
679
+ """
680
+ self._ensure_format_supported(image)
681
+
682
+ if not isinstance(size, tuple):
683
+ size = (size, size)
684
+
685
+ # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
686
+ if is_torch_tensor(image) or isinstance(image, np.ndarray):
687
+ if image.ndim == 2:
688
+ image = self.expand_dims(image)
689
+ image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
690
+ else:
691
+ image_shape = (image.size[1], image.size[0])
692
+
693
+ top = (image_shape[0] - size[0]) // 2
694
+ bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
695
+ left = (image_shape[1] - size[1]) // 2
696
+ right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
697
+
698
+ # For PIL Images we have a method to crop directly.
699
+ if isinstance(image, PIL.Image.Image):
700
+ return image.crop((left, top, right, bottom))
701
+
702
+ # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
703
+ channel_first = image.shape[0] in [1, 3]
704
+
705
+ # Transpose (height, width, n_channels) format images
706
+ if not channel_first:
707
+ if isinstance(image, np.ndarray):
708
+ image = image.transpose(2, 0, 1)
709
+ if is_torch_tensor(image):
710
+ image = image.permute(2, 0, 1)
711
+
712
+ # Check if cropped area is within image boundaries
713
+ if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
714
+ return image[..., top:bottom, left:right]
715
+
716
+ # Otherwise, we may need to pad if the image is too small. Oh joy...
717
+ new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
718
+ if isinstance(image, np.ndarray):
719
+ new_image = np.zeros_like(image, shape=new_shape)
720
+ elif is_torch_tensor(image):
721
+ new_image = image.new_zeros(new_shape)
722
+
723
+ top_pad = (new_shape[-2] - image_shape[0]) // 2
724
+ bottom_pad = top_pad + image_shape[0]
725
+ left_pad = (new_shape[-1] - image_shape[1]) // 2
726
+ right_pad = left_pad + image_shape[1]
727
+ new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
728
+
729
+ top += top_pad
730
+ bottom += top_pad
731
+ left += left_pad
732
+ right += left_pad
733
+
734
+ new_image = new_image[
735
+ ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
736
+ ]
737
+
738
+ return new_image
739
+
740
+ def flip_channel_order(self, image):
741
+ """
742
+ Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
743
+ `image` to a NumPy array if it's a PIL Image.
744
+
745
+ Args:
746
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
747
+ The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
748
+ be first.
749
+ """
750
+ self._ensure_format_supported(image)
751
+
752
+ if isinstance(image, PIL.Image.Image):
753
+ image = self.to_numpy_array(image)
754
+
755
+ return image[::-1, :, :]
756
+
757
+ def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
758
+ """
759
+ Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
760
+ counter clockwise around its centre.
761
+
762
+ Args:
763
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
764
+ The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
765
+ rotating.
766
+
767
+ Returns:
768
+ image: A rotated `PIL.Image.Image`.
769
+ """
770
+ resample = resample if resample is not None else PIL.Image.NEAREST
771
+
772
+ self._ensure_format_supported(image)
773
+
774
+ if not isinstance(image, PIL.Image.Image):
775
+ image = self.to_pil_image(image)
776
+
777
+ return image.rotate(
778
+ angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
779
+ )
780
+
781
+
782
+ def validate_annotations(
783
+ annotation_format: AnnotationFormat,
784
+ supported_annotation_formats: Tuple[AnnotationFormat, ...],
785
+ annotations: List[Dict],
786
+ ) -> None:
787
+ if annotation_format not in supported_annotation_formats:
788
+ raise ValueError(f"Unsupported annotation format: {format} must be one of {supported_annotation_formats}")
789
+
790
+ if annotation_format is AnnotationFormat.COCO_DETECTION:
791
+ if not valid_coco_detection_annotations(annotations):
792
+ raise ValueError(
793
+ "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
794
+ "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
795
+ "being a list of annotations in the COCO format."
796
+ )
797
+
798
+ if annotation_format is AnnotationFormat.COCO_PANOPTIC:
799
+ if not valid_coco_panoptic_annotations(annotations):
800
+ raise ValueError(
801
+ "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
802
+ "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
803
+ "the latter being a list of annotations in the COCO format."
804
+ )
805
+
806
+
807
+ def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
808
+ unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
809
+ if unused_keys:
810
+ unused_key_str = ", ".join(unused_keys)
811
+ # TODO raise a warning here instead of simply logging?
812
+ logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
modeling_dualvitok.py ADDED
@@ -0,0 +1,653 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import math
6
+ from typing import Optional, Tuple, Union, List, Callable
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn import Module
12
+
13
+ from einops import rearrange, repeat, pack, unpack
14
+ from einx import get_at
15
+
16
+ from torch.utils.checkpoint import checkpoint
17
+ from transformers import AutoImageProcessor
18
+ from transformers.modeling_utils import PreTrainedModel, get_parameter_device, get_parameter_dtype
19
+
20
+ from .configuration_dualvitok import DualViTokConfig
21
+ from .modeling_movqgan import MoVQModel, MoVQEncoder, MoVQDecoder, Decoder
22
+
23
+ from .configuration_qwen2vit import Qwen2VLVisionConfig
24
+ from .modeling_qwen2vit import Qwen2VisionTransformerPretrainedModel, \
25
+ VisionRotaryEmbedding, Qwen2VLBatchVisionBlock
26
+
27
+ try:
28
+ import xformers.ops as xops
29
+
30
+ is_xformers_available = True
31
+ except Exception:
32
+ is_xformers_available = False
33
+
34
+ from packaging import version as _version
+ # Compare versions numerically; plain string comparison misorders e.g. "2.10.0" vs "2.9.0".
+ if _version.parse(torch.__version__) > _version.parse("2.1.2"):
35
+ IS_SDPA_AVAILABLE = True
36
+ else:
37
+ IS_SDPA_AVAILABLE = False
38
+
39
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
40
+ sys.path.append(cur_dir)
41
+
42
+
43
+ # helper functions
44
+
45
+ def exists(v):
46
+ return v is not None
47
+
48
+
49
+ def identity(t):
50
+ return t
51
+
52
+
53
+ def default(v, d):
54
+ return v if exists(v) else d
55
+
56
+
57
+ def pack_one(t, pattern):
58
+ packed, packed_shape = pack([t], pattern)
59
+
60
+ def inverse(out, inv_pattern=None):
61
+ inv_pattern = default(inv_pattern, pattern)
62
+ out, = unpack(out, packed_shape, inv_pattern)
63
+ return out
64
+
65
+ return packed, inverse
66
+
67
+
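
A small sketch of the `pack_one` helper above: it flattens all middle dimensions and hands back a closure that restores them:

```python
import torch

t = torch.randn(2, 4, 4, 8)
packed, inverse = pack_one(t, 'b * d')   # packed: (2, 16, 8)
restored = inverse(packed)               # back to (2, 4, 4, 8)
assert restored.shape == t.shape
```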
68
+ # class
69
+
70
+
71
+ class SimVQ(Module):
72
+ def __init__(
73
+ self,
74
+ dim,
75
+ codebook_size,
76
+ codebook_transform: Module | None = None,
77
+ init_fn: Callable = identity,
78
+ channel_first=True,
79
+ input_to_quantize_commit_loss_weight=0.25,
80
+ commitment_weight=1.,
81
+ frozen_codebook_dim=None  # the frozen codebook may have a different dimensionality than the projection
82
+ ):
83
+ super().__init__()
84
+ self.codebook_size = codebook_size
85
+ self.channel_first = channel_first
86
+
87
+ frozen_codebook_dim = default(frozen_codebook_dim, dim)
88
+ codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
89
+ codebook = init_fn(codebook)
90
+
91
+ # the codebook is actually implicit from a linear layer from frozen gaussian or uniform
92
+
93
+ if not exists(codebook_transform):
94
+ codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias=False)
95
+
96
+ self.code_transform = codebook_transform
97
+
98
+ self.register_buffer('frozen_codebook', codebook)
99
+
100
+ # commit loss weighting - weighing input to quantize a bit less is crucial for it to work
101
+ self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
102
+
103
+ # total commitment loss weight
104
+ self.commitment_weight = commitment_weight
105
+
106
+ @property
107
+ def codebook(self):
108
+ return self.code_transform(self.frozen_codebook)
109
+
110
+ def indices_to_codes(
111
+ self,
112
+ indices
113
+ ):
115
+
116
+ frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
117
+ quantized = self.code_transform(frozen_codes)
118
+
119
+ if self.channel_first:
120
+ quantized = rearrange(quantized, 'b ... d -> b d ...')
121
+
122
+ return quantized
123
+
124
+ def forward(
125
+ self,
126
+ x
127
+ ):
128
+ if self.channel_first:
129
+ x = rearrange(x, 'b d ... -> b ... d')
130
+
131
+ x, inverse_pack = pack_one(x, 'b * d')
132
+
133
+ implicit_codebook = self.codebook
134
+
135
+ with torch.no_grad():
136
+ dist = torch.cdist(x, implicit_codebook)
137
+ indices = dist.argmin(dim=-1)
138
+
139
+ # select codes
140
+
141
+ quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
142
+
143
+ # commit loss and straight through, as was done in the paper
144
+
145
+ commit_loss = (
146
+ F.mse_loss(x.detach(), quantized) +
147
+ F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
148
+ )
149
+
150
+ quantized = (quantized - x).detach() + x
151
+
152
+ quantized = inverse_pack(quantized)
153
+ indices = inverse_pack(indices, 'b *')
154
+
155
+ if self.channel_first:
156
+ quantized = rearrange(quantized, 'b ... d -> b d ...')
157
+
158
+ return quantized, commit_loss * self.commitment_weight, indices
159
+
160
+
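A minimal usage sketch for the SimVQ quantizer above (shapes are illustrative, not from the original file):

    quantizer = SimVQ(dim=32, codebook_size=1024, channel_first=True)
    feats = torch.randn(2, 32, 8, 8)                # [batch, dim, h, w] feature map
    quantized, commit_loss, indices = quantizer(feats)
    assert quantized.shape == feats.shape           # straight-through output
    assert indices.shape == (2, 8, 8)               # one code id per spatial position
    rebuilt = quantizer.indices_to_codes(indices)   # [2, 32, 8, 8] codes from ids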
161
+ def init_weights(m):
162
+ if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
163
+ if m.weight is not None:
164
+ nn.init.constant_(m.weight, 1)
165
+ if m.bias is not None:
166
+ nn.init.constant_(m.bias, 0)
167
+ elif isinstance(m, nn.Linear):
168
+ nn.init.xavier_uniform_(m.weight)
169
+ if m.bias is not None:
170
+ nn.init.constant_(m.bias, 0)
171
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) \
172
+ or isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
173
+ w = m.weight.data
174
+ nn.init.xavier_uniform_(w)
175
+ if m.bias is not None:
176
+ nn.init.constant_(m.bias, 0)
177
+ elif isinstance(m, nn.Embedding):
178
+ nn.init.normal_(m.weight, mean=0, std=1)
179
+
180
+
181
+ class ScalingLayerForQwen2ViT:
182
+ def __init__(
183
+ self,
184
+ min_pixels: int = 56 * 56,
185
+ max_pixels: int = 28 * 28 * 1280,
186
+ patch_size: int = 14,
187
+ temporal_patch_size: int = 2,
188
+ merge_size: int = 2,
189
+ **kwargs,
190
+ ) -> None:
191
+ super().__init__()  # plain class; forwarding **kwargs to object.__init__ would raise a TypeError
192
+ OPENAI_CLIP_MEAN = torch.as_tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]
193
+ OPENAI_CLIP_STD = torch.as_tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]
194
+
195
+ self.image_mean = OPENAI_CLIP_MEAN
196
+ self.image_std = OPENAI_CLIP_STD
197
+ self.min_pixels = min_pixels
198
+ self.max_pixels = max_pixels
199
+ self.patch_size = patch_size
200
+ self.temporal_patch_size = temporal_patch_size
201
+ self.merge_size = merge_size
202
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
203
+
204
+ def __call__(self, images):
205
+ if images.ndim == 4:
206
+ images = images.unsqueeze(1)
207
+ batch_size, temporal, channel, height, width = images.shape
208
+
209
+ factor = self.patch_size * self.merge_size
210
+
211
+ resized_height, resized_width = height // factor * factor, width // factor * factor
212
+
213
+ images = (images + 1) / 2 # rescale to [0, 1.]
214
+
215
+ images = torch.nn.functional.interpolate(
216
+ images.flatten(0, 1).float(),
217
+ size=(resized_height, resized_width),
218
+ mode='bicubic',
219
+ align_corners=False,
220
+ antialias=True
221
+ ).to(images.dtype)
222
+
223
+ images = images.clamp(0, 1) # rescale to [0, 1.]
224
+ images = ((images - self.image_mean.to(images)) / self.image_std.to(images))
225
+
226
+ images = rearrange(images, '(b t) c h w -> b t c h w', b=batch_size, t=temporal)
227
+ if temporal == 1:
228
+ images = images.repeat_interleave(self.temporal_patch_size, dim=1)
229
+ temporal = self.temporal_patch_size
230
+
231
+ grid_t = temporal // self.temporal_patch_size
232
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
233
+
234
+ images = images.reshape(
235
+ batch_size * grid_t,
236
+ self.temporal_patch_size,
237
+ channel,
238
+ -1
239
+ )
240
+
241
+ images = rearrange(images, 'b p c n -> b n (c p)')
242
+ images = images.reshape(
243
+ batch_size * grid_t,
244
+ grid_h // self.merge_size,
245
+ self.merge_size,
246
+ self.patch_size,
247
+ grid_w // self.merge_size,
248
+ self.merge_size,
249
+ self.patch_size,
250
+ -1
251
+ )
252
+ images = rearrange(images, 'b h k s1 w l s2 n -> (b h w k l) (n s1 s2)')
253
+
254
+ return dict(image=images, image_grid_thw=torch.as_tensor([[grid_t, grid_h, grid_w] for _ in range(batch_size)]))
255
+
256
+
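A worked shape example for the scaling layer above (values illustrative): with patch_size=14 and merge_size=2 the resize factor is 28, so a 448x448 input keeps its size and yields a 32x32 patch grid per image.

    layer = ScalingLayerForQwen2ViT()            # defaults: patch 14, merge 2
    imgs = torch.rand(2, 3, 448, 448) * 2 - 1    # inputs are expected in [-1, 1]
    out = layer(imgs)
    out['image'].shape                           # [2*1*32*32, 3*2*14*14] = [2048, 1176]
    out['image_grid_thw']                        # tensor([[1, 32, 32], [1, 32, 32]])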
257
+ class SemanticEncoder(nn.Module):
258
+ def __init__(self,
259
+ semantic_encoder,
260
+ z_channels=4,
261
+ num_blocks=2,
262
+ embed_dim=1280,
263
+ proj_layer='linear',
264
+ attn_implementation='xformers',
265
+ target_mlp='identity',
266
+ ):
267
+ super().__init__()
268
+ self.embed_dim = embed_dim
269
+
270
+ if isinstance(semantic_encoder, str):
271
+ self.model = Qwen2VisionTransformerPretrainedModel.from_pretrained(
272
+ semantic_encoder,
273
+ attn_implementation=attn_implementation
274
+ )
275
+ elif isinstance(semantic_encoder, dict):
276
+ config = Qwen2VLVisionConfig(**semantic_encoder, attn_implementation=attn_implementation)
277
+ self.model = Qwen2VisionTransformerPretrainedModel(config)
278
+ else:
279
+ raise ValueError(f"Invalid semantic_encoder: {semantic_encoder}")
280
+ input_channels = self.model.config.hidden_size
281
+
282
+ for p in self.model.parameters():
283
+ p.requires_grad = False
284
+
285
+ self.proj_in = nn.Conv2d(input_channels, embed_dim, 1, 1) if input_channels != embed_dim else nn.Identity()
286
+
287
+ config = Qwen2VLVisionConfig(depth=num_blocks,
288
+ embed_dim=embed_dim, )
289
+ head_dim = config.embed_dim // config.num_heads
290
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
291
+
292
+ self.blocks = nn.ModuleList(
293
+ [Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
294
+ )
295
+
296
+ if proj_layer == 'norm_linear':
297
+ self.proj_out = nn.Sequential(
298
+ nn.LayerNorm(embed_dim),
299
+ nn.Linear(
300
+ embed_dim,
301
+ z_channels,
302
+ )
303
+ )
304
+ elif proj_layer == 'linear':
305
+ self.proj_out = nn.Sequential(
306
+ nn.Linear(
307
+ embed_dim,
308
+ z_channels,
309
+ )
310
+ )
311
+ elif proj_layer == 'mlp':
312
+ self.proj_out = nn.Sequential(
313
+ nn.Linear(embed_dim, embed_dim),
314
+ nn.Tanh(),
315
+ nn.Linear(embed_dim, z_channels),
316
+ )
317
+ else:
318
+ raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
319
+
320
+ if target_mlp == 'identity':
321
+ self.target_mlp = nn.Sequential(
322
+ nn.Identity(),
323
+ )
324
+ elif target_mlp == 'norm':
325
+ self.target_mlp = nn.Sequential(
326
+ nn.LayerNorm(input_channels, eps=1e-6, elementwise_affine=False),
327
+ )
328
+ self.init_weight()
329
+
330
+ def init_weight(self):
331
+ self.proj_in.apply(init_weights)
332
+ self.blocks.apply(init_weights)
333
+ self.proj_out.apply(init_weights)
334
+ self.target_mlp.apply(init_weights)
335
+
336
+ def rot_pos_emb(self, grid_thw, max_seq_len):
337
+ pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
338
+ for idx, (t, h, w) in enumerate(grid_thw):
339
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
340
+ hpos_ids = hpos_ids.flatten()
341
+
342
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
343
+ wpos_ids = wpos_ids.flatten()
344
+
345
+ current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
346
+ pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
347
+ max_grid_size = grid_thw[:, 1:].max()
348
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
349
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
350
+ return rotary_pos_emb
351
+
352
+ def forward(self, x, grid_thw):
353
+ x = self.model(x, grid_thw=grid_thw)
354
+
355
+ x = x_target = self.target_mlp(x)
356
+
357
+ x = F.linear(x,
358
+ self.proj_in.weight.view(self.proj_in.weight.shape[0], -1),
359
+ self.proj_in.bias)
360
+
361
+ new_grid_thw = torch.as_tensor([[t, h // 2, w // 2] for t, h, w in grid_thw])
362
+
363
+ seq_lens = [t_i * h_i * w_i for t_i, h_i, w_i in new_grid_thw]
364
+ max_seq_len = max(seq_lens)
365
+
366
+ x = rearrange(x, '(b h w) c -> b (h w) c', h=new_grid_thw[0, 1], w=new_grid_thw[0, 2])
367
+
368
+ rotary_pos_emb = self.rot_pos_emb(new_grid_thw, max_seq_len)
369
+
370
+ for blk in self.blocks:
371
+ x = blk(x, rotary_pos_emb=rotary_pos_emb)
372
+
373
+ x = self.proj_out(x) # [b, max_length, d]
374
+
375
+ t, h, w = new_grid_thw[0]
376
+ b = len(grid_thw)
377
+ x = rearrange(x, 'b (h w) c -> b c h w', b=b, h=h, w=w)
378
+ x_target = rearrange(x_target, '(b h w) c -> b c h w', b=b, h=h, w=w)
379
+ return x, x_target
380
+
381
+
382
+ class SemanticDecoder(nn.Module):
383
+ def __init__(self,
384
+ z_channels=4,
385
+ embed_dim=1280,
386
+ num_blocks=2,
387
+ output_channels=1280,
388
+ attn_implementation='xformers',
389
+ proj_layer='linear_norm'):
390
+ super().__init__()
391
+ self.proj_in = nn.Linear(z_channels, embed_dim)
392
+
393
+ self.output_channels = output_channels
394
+ config = Qwen2VLVisionConfig(depth=num_blocks, embed_dim=embed_dim)
395
+
396
+ self.blocks = nn.ModuleList(
397
+ [Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
398
+ )
399
+ head_dim = config.embed_dim // config.num_heads
400
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
401
+
402
+ if proj_layer == 'norm_linear':
403
+ self.proj_out = nn.Sequential(
404
+ nn.LayerNorm(embed_dim),
405
+ nn.Linear(embed_dim, output_channels),
406
+ )
407
+ elif proj_layer == 'linear':
408
+ self.proj_out = nn.Sequential(
409
+ nn.Linear(embed_dim, output_channels)
410
+ )
411
+ elif proj_layer == 'mlp':
412
+ self.proj_out = nn.Sequential(
413
+ nn.Linear(embed_dim, embed_dim),
414
+ nn.Tanh(),
415
+ nn.Linear(embed_dim, output_channels),
416
+ )
417
+ elif proj_layer == 'linear_norm':
418
+ self.proj_out = nn.Sequential(
419
+ nn.Linear(embed_dim, output_channels),
420
+ nn.LayerNorm(output_channels),
421
+ )
422
+
423
+ self.apply(init_weights)
424
+
425
+ @property
426
+ def last_layer(self):
427
+ return self.proj_out[-1].weight
428
+
429
+ def rot_pos_emb(self, grid_thw, max_seq_len):
430
+ pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
431
+ for idx, (t, h, w) in enumerate(grid_thw):
432
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
433
+ hpos_ids = hpos_ids.flatten()
434
+
435
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
436
+ wpos_ids = wpos_ids.flatten()
437
+
438
+ current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
439
+ pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
440
+ max_grid_size = grid_thw[:, 1:].max()
441
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
442
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
443
+ return rotary_pos_emb
444
+
445
+ def forward(self, z: torch.Tensor):
446
+ x = z
447
+
448
+ b, c, h, w = x.shape
449
+
450
+ x = rearrange(x, 'b c h w -> b (h w) c')
451
+
452
+ grid_thw = torch.as_tensor([[1, h, w] for _ in range(b)])
453
+ seq_lens = [t * h * w for t, h, w in grid_thw]
454
+ max_seq_len = max(seq_lens)
455
+
456
+ x = self.proj_in(x)
457
+
458
+ rotary_pos_emb = self.rot_pos_emb(grid_thw, max_seq_len)
459
+
460
+ for blk in self.blocks:
461
+ x = blk(x, rotary_pos_emb=rotary_pos_emb)
462
+
463
+ x = self.proj_out(x)
464
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
465
+ return x
466
+
467
+
468
+ class DualViTokPretrainModel(PreTrainedModel):
469
+ """
470
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
471
+ models.
472
+ """
473
+
474
+ config_class = DualViTokConfig
475
+ base_model_prefix = "dualvitok"
476
+ main_input_name = "pixel_values"
477
+ _no_split_modules = ["BatchQwen2VLVisionBlock", "MoVQResnetBlock", "MoVQAttnBlock", "MoVQResnetTemporalBlock"]
478
+ _supports_flash_attn_2 = True
479
+ _supports_sdpa = True
480
+ _supports_cache_class = True
481
+ _supports_static_cache = True
482
+
483
+ def _init_weights(self, module):
484
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)):
485
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
486
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
487
+ elif isinstance(module, nn.Linear):
488
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
489
+ if module.bias is not None:
490
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
491
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
492
+ nn.init.uniform_(module.bias, -bound, bound)
493
+ elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
494
+ nn.init.constant_(module.weight, 1)
495
+ nn.init.constant_(module.bias, 0)
496
+
497
+
498
+ class DualViTok(DualViTokPretrainModel):
499
+ def __init__(self, config: DualViTokConfig):
500
+ super().__init__(config)
501
+ self.config = config
502
+
503
+ self._semantic_channel = config.semantic_encoder.z_channels
504
+ self._pixel_channel = config.pixel_encoder.z_channels
505
+
506
+ self.semantic_encoder = SemanticEncoder(
507
+ semantic_encoder=config.semantic_encoder.pretrained_semantic_encoder,
508
+ z_channels=config.semantic_encoder.z_channels,
509
+ num_blocks=config.semantic_encoder.num_blocks,
510
+ embed_dim=config.semantic_encoder.embed_dim,
511
+ proj_layer=config.semantic_encoder.out_layer,
512
+ attn_implementation=config.attn_implementation,
513
+ target_mlp=config.semantic_encoder.target_mlp, )
514
+ self.semantic_decoder = SemanticDecoder(
515
+ z_channels=config.semantic_decoder.z_channels,
516
+ embed_dim=config.semantic_decoder.embed_dim,
517
+ num_blocks=config.semantic_decoder.num_blocks,
518
+ output_channels=config.semantic_decoder.out_channels,
519
+ attn_implementation=config.attn_implementation,
520
+ proj_layer=config.semantic_decoder.out_layer,
521
+ )
522
+
523
+ if config.semantic_quantizer_type.lower() == 'simvq':
524
+ self.semantic_quantizer = SimVQ(
525
+ dim=config.semantic_encoder.z_channels,
526
+ codebook_size=config.semantic_quantizer_codebook_size,
527
+ )
528
+ elif config.semantic_quantizer_type.lower() == 'vq':
529
+ raise NotImplementedError
530
+ self.semantic_quantizer = VQ(
531
+ dim=config.semantic_encoder.z_channels,
532
+ codebook_size=config.semantic_quantizer_codebook_size,
533
+ )
534
+
535
+ self.pixel_encoder = MoVQEncoder(config.pixel_encoder)
536
+ self.pixel_quant_conv = nn.Conv2d(config.pixel_encoder.z_channels, config.pixel_encoder.embed_dim, 1)
537
+
538
+ if config.pixel_quantizer_type.lower() == 'simvq':
539
+ self.pixel_quantizer = SimVQ(
540
+ dim=config.pixel_encoder.z_channels,
541
+ codebook_size=config.pixel_quantizer_codebook_size,
542
+ )
543
+ elif config.pixel_quantizer_type.lower() == 'vq':
544
+ raise NotImplementedError
545
+ self.pixel_quantizer = VQ(
546
+ dim=config.pixel_encoder.z_channels,
547
+ codebook_size=config.pixel_quantizer_codebook_size,
548
+ )
549
+
550
+ self.pixel_post_quant_conv = nn.Conv2d(config.pixel_decoder.embed_dim,
551
+ config.pixel_decoder.z_channels, 1)
552
+
553
+ self.pixel_decoder = MoVQDecoder(config.pixel_decoder)
554
+
555
+ self.scaling_layer = ScalingLayerForQwen2ViT()
556
+
557
+ @property
558
+ def device(self):
559
+ return get_parameter_device(self)
560
+
561
+ @property
562
+ def dtype(self):
563
+ return get_parameter_dtype(self)
564
+
565
+ @property
566
+ def pixel_channel(self):
567
+ return self._pixel_channel
568
+
569
+ @property
570
+ def semantic_channel(self):
571
+ return self._semantic_channel
572
+
573
+ def encode(self, image: torch.FloatTensor):
574
+ scale_output = self.scaling_layer(image)
575
+ image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
576
+
577
+ h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
578
+ quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
579
+
580
+ h_pixel = self.pixel_encoder(image_gen)
581
+ h_pixel = self.pixel_quant_conv(h_pixel)
582
+
583
+ quant_pixel, emb_loss_pixel, info_pixel = self.pixel_quantizer(h_pixel.float())
584
+
585
+ return (quant_semantic, emb_loss_semantic, info_semantic, target_semantic), \
586
+ (quant_pixel, emb_loss_pixel, info_pixel)
587
+
588
+ def encode_code(self, *args, **kwargs):
589
+ (_, _, semantic_indices, _), \
590
+ (_, _, pixel_indices) = self.encode(*args, **kwargs)
591
+ return semantic_indices, pixel_indices
592
+
593
+ def indices_to_codes(self, semantic_indices, pixel_indices):
594
+ quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
595
+ quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
596
+ return quant_semantic, quant_pixel
597
+
598
+ def encode_semantic(self, image: torch.FloatTensor):
599
+ scale_output = self.scaling_layer(image)
600
+ image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
601
+
602
+ h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
603
+ quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
604
+ return quant_semantic, emb_loss_semantic, info_semantic, target_semantic
605
+
606
+ def merge_quants(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
607
+ quant_semantic_resized = F.interpolate(
608
+ quant_semantic, quant_pixel.shape[-2:], mode='bicubic'
609
+ ).to(quant_semantic.dtype)
610
+ quant_semantic = quant_semantic_resized
611
+
612
+ quant = torch.cat([quant_semantic, quant_pixel], dim=1)
613
+
614
+ return quant
615
+
616
+ def decode(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
617
+ quant = self.merge_quants(quant_semantic, quant_pixel)
618
+ quant2 = self.pixel_post_quant_conv(quant)
619
+ x = self.pixel_decoder(quant2, quant)
620
+ return x
621
+
622
+ def decode_code(self, semantic_indices, pixel_indices):
623
+ quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
624
+ quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
625
+ return self.decode(quant_semantic, quant_pixel)
626
+
627
+ def decode_semantic(self, x: List[torch.Tensor]):
628
+ return self.semantic_decoder(x)
629
+
630
+ def forward(self, pixel_values: torch.FloatTensor):
631
+ (quant_semantic, diff_semantic, _, target_semantic), \
632
+ (quant_pixel, diff_pixel, _) = self.encode(pixel_values)
633
+ dec = self.decode(quant_semantic, quant_pixel)
634
+ dec_semantic = self.decode_semantic(quant_semantic)
635
+ return (dec_semantic, diff_semantic, target_semantic), (dec, diff_pixel)
636
+
637
+ def build_sdxl_decoder(self, path='ILLUME-MLLM/dualvitok-sdxl-decoder',
638
+ image_processor=None,
639
+ torch_dtype=torch.float16,
640
+ add_watermarker=False,
641
+ device='cuda',
642
+ ):
643
+ from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
644
+
645
+ if image_processor is None:
646
+ image_processor = AutoImageProcessor.from_pretrained('ILLUME-MLLM/dualvitok', trust_remote_code=True)
647
+
648
+ return StableDiffusionXLDecoderPipeline.from_pretrained(path,
649
+ torch_dtype=torch_dtype,
650
+ add_watermarker=add_watermarker,
651
+ vq_model=self,
652
+ vq_image_processor=image_processor).to(device)
653
+
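An end-to-end round-trip sketch for DualViTok (module and checkpoint names are assumptions, not confirmed by this file):

    import torch
    from modeling_dualvitok import DualViTok  # assumed module name

    model = DualViTok.from_pretrained('ILLUME-MLLM/dualvitok').eval()  # assumed repo id
    pixel_values = torch.rand(1, 3, 256, 256) * 2 - 1                  # [-1, 1] range
    with torch.no_grad():
        semantic_ids, pixel_ids = model.encode_code(pixel_values)      # two token streams
        recon = model.decode_code(semantic_ids, pixel_ids)             # reconstructed image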
modeling_movqgan.py ADDED
@@ -0,0 +1,828 @@
1
+ """ MoVQ model """
2
+
3
+ import math
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.utils.checkpoint import checkpoint
11
+ from transformers.modeling_utils import PreTrainedModel
12
+
13
+ from .configuration_movqgan import MoVQConfig
14
+
15
+ try:
16
+ import xformers.ops as xops
17
+
18
+ is_xformers_available = True
19
+ except Exception as e:
20
+ is_xformers_available = False
21
+
22
+ if torch.__version__ > "2.1.2":
23
+ IS_SDPA_AVAILABLE = True
24
+ else:
25
+ IS_SDPA_AVAILABLE = False
26
+
27
+
28
+ class MoVQActivation(nn.Module):
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def __call__(self, x: torch.Tensor):
34
+ return x * torch.sigmoid(x)
35
+
36
+
37
+ class MoVQUpsample(nn.Module):
38
+
39
+ def __init__(self, in_channels: int):
40
+ super().__init__()
41
+ self.conv = nn.Conv2d(
42
+ in_channels,
43
+ in_channels,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1,
47
+ )
48
+
49
+ def forward(self, x: torch.Tensor):
50
+ x = F.interpolate(x.float(), scale_factor=2.0, mode="nearest").to(x.dtype)
51
+ x = self.conv(x)
52
+ return x
53
+
54
+
55
+ class DCDownBlock2d(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int = None, downsample: bool = True,
57
+ shortcut: bool = True) -> None:
58
+ super().__init__()
59
+ out_channels = out_channels if out_channels else in_channels
60
+
61
+ self.downsample = downsample
62
+ self.factor = 2
63
+ self.stride = 1 if downsample else 2
64
+ self.group_size = in_channels * self.factor ** 2 // out_channels
65
+ self.shortcut = shortcut
66
+
67
+ out_ratio = self.factor ** 2
68
+ if downsample:
69
+ assert out_channels % out_ratio == 0
70
+ out_channels = out_channels // out_ratio
71
+
72
+ self.conv = nn.Conv2d(
73
+ in_channels,
74
+ out_channels,
75
+ kernel_size=3,
76
+ stride=self.stride,
77
+ padding=1,
78
+ )
79
+
80
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
81
+ x = self.conv(hidden_states)
82
+ if self.downsample:
83
+ x = F.pixel_unshuffle(x, self.factor)
84
+
85
+ if self.shortcut:
86
+ y = F.pixel_unshuffle(hidden_states, self.factor)
87
+ y = y.unflatten(1, (-1, self.group_size))
88
+ y = y.mean(dim=2)
89
+ hidden_states = x + y
90
+ else:
91
+ hidden_states = x
92
+
93
+ return hidden_states
94
+
95
+
96
+ class DCUpBlock2d(nn.Module):
97
+ def __init__(
98
+ self,
99
+ in_channels: int,
100
+ out_channels: int = None,
101
+ interpolate: bool = False,
102
+ shortcut: bool = True,
103
+ interpolation_mode: str = "nearest",
104
+ ) -> None:
105
+ super().__init__()
106
+ out_channels = out_channels if out_channels else in_channels
107
+
108
+ self.interpolate = interpolate
109
+ self.interpolation_mode = interpolation_mode
110
+ self.shortcut = shortcut
111
+ self.factor = 2
112
+ self.repeats = out_channels * self.factor ** 2 // in_channels
113
+
114
+ out_ratio = self.factor ** 2
115
+
116
+ if not interpolate:
117
+ out_channels = out_channels * out_ratio
118
+
119
+ self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
120
+
121
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
122
+ if self.interpolate:
123
+ x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
124
+ x = self.conv(x)
125
+ else:
126
+ x = self.conv(hidden_states)
127
+ x = F.pixel_shuffle(x, self.factor)
128
+
129
+ if self.shortcut:
130
+ y = hidden_states.repeat_interleave(self.repeats, dim=1)
131
+ y = F.pixel_shuffle(y, self.factor)
132
+ hidden_states = x + y
133
+ else:
134
+ hidden_states = x
135
+
136
+ return hidden_states
137
+
138
+
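A quick shape check for the DC down/up blocks above (channel choices assumed): pixel_unshuffle trades a factor² of spatial resolution for channels, and the averaged/repeated shortcut keeps channel counts aligned.

    down = DCDownBlock2d(in_channels=64, out_channels=128)  # group_size = 64*4 // 128 = 2
    up = DCUpBlock2d(in_channels=128, out_channels=64)      # repeats = 64*4 // 128 = 2
    x = torch.randn(1, 64, 32, 32)
    y = down(x)                      # [1, 128, 16, 16]: conv to 32 ch, unshuffle x4
    assert up(y).shape == x.shape    # shuffle path restores [1, 64, 32, 32]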
139
+ class MoVQDownsample(nn.Module):
140
+
141
+ def __init__(self, in_channels: int):
142
+ super().__init__()
143
+ self.conv = nn.Conv2d(
144
+ in_channels,
145
+ in_channels,
146
+ kernel_size=3,
147
+ stride=2,
148
+ padding=0,
149
+ )
150
+
151
+ def forward(self, x: torch.Tensor):
152
+ pad = (0, 1, 0, 1)
153
+ x = F.pad(x, pad, mode="constant", value=0)
154
+ x = self.conv(x)
155
+ return x
156
+
157
+
158
+ class MoVQSpatialNorm(nn.Module):
159
+
160
+ def __init__(
161
+ self,
162
+ f_channels: int,
163
+ zq_channels: int,
164
+ norm_layer: nn.Module = nn.GroupNorm,
165
+ add_conv: bool = False,
166
+ num_groups: int = 32,
167
+ eps: float = 1e-6,
168
+ affine: bool = True,
169
+ ):
170
+ super().__init__()
171
+ self.norm_layer = norm_layer(
172
+ num_channels=f_channels,
173
+ num_groups=num_groups,
174
+ eps=eps,
175
+ affine=affine,
176
+ )
177
+
178
+ self.add_conv = add_conv
179
+ if self.add_conv:
180
+ self.conv = nn.Conv2d(
181
+ zq_channels,
182
+ zq_channels,
183
+ kernel_size=3,
184
+ stride=1,
185
+ padding=1,
186
+ )
187
+
188
+ self.conv_y = nn.Conv2d(
189
+ zq_channels,
190
+ f_channels,
191
+ kernel_size=1,
192
+ stride=1,
193
+ padding=0,
194
+ )
195
+ self.conv_b = nn.Conv2d(
196
+ zq_channels,
197
+ f_channels,
198
+ kernel_size=1,
199
+ stride=1,
200
+ padding=0,
201
+ )
202
+
203
+ def forward(self, x: torch.Tensor, zq: torch.Tensor):
204
+ zq = F.interpolate(zq.float(), size=x.shape[-2:], mode="nearest").to(zq.dtype)
205
+
206
+ if self.add_conv:
207
+ zq = self.conv(zq)
208
+
209
+ x = self.norm_layer(x)
210
+ x = x * self.conv_y(zq) + self.conv_b(zq)
211
+ return x
212
+
213
+
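In effect the spatial norm above computes x <- GroupNorm(x) * conv_y(zq) + conv_b(zq), with zq first resized to x's spatial size. A minimal sketch (channel sizes assumed):

    norm = MoVQSpatialNorm(f_channels=64, zq_channels=16)
    x = torch.randn(1, 64, 32, 32)
    zq = torch.randn(1, 16, 8, 8)    # upsampled to 32x32 inside forward
    assert norm(x, zq).shape == x.shape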
214
+ class MoVQResnetBlock(nn.Module):
215
+
216
+ def __init__(
217
+ self,
218
+ in_channels: int,
219
+ out_channels: Optional[int] = None,
220
+ conv_shortcut: bool = False,
221
+ dropout: float = 0.0,
222
+ zq_ch: Optional[int] = None,
223
+ add_conv: bool = False,
224
+ ):
225
+ super().__init__()
226
+ self.in_channels = in_channels
227
+ out_channels = in_channels if out_channels is None else out_channels
228
+ self.out_channels = out_channels
229
+ self.use_conv_shortcut = conv_shortcut
230
+ self.zq_ch = zq_ch
231
+
232
+ if zq_ch is None:
233
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
234
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
235
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
236
+ else:
237
+ self.norm1 = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
238
+ self.norm2 = MoVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv)
239
+
240
+ self.conv1 = nn.Conv2d(
241
+ in_channels,
242
+ out_channels,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1,
246
+ )
247
+
248
+ self.dropout = nn.Dropout(dropout)
249
+ self.conv2 = nn.Conv2d(
250
+ out_channels,
251
+ out_channels,
252
+ kernel_size=3,
253
+ stride=1,
254
+ padding=1,
255
+ )
256
+
257
+ self.act = MoVQActivation()
258
+
259
+ if self.in_channels != self.out_channels:
260
+ if self.use_conv_shortcut:
261
+ self.conv_shortcut = nn.Conv2d(
262
+ in_channels,
263
+ out_channels,
264
+ kernel_size=3,
265
+ stride=1,
266
+ padding=1,
267
+ )
268
+ else:
269
+ self.nin_shortcut = nn.Conv2d(
270
+ in_channels,
271
+ out_channels,
272
+ kernel_size=1,
273
+ stride=1,
274
+ padding=0,
275
+ )
276
+
277
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
278
+ norm_args = tuple() if self.zq_ch is None else (zq,)
279
+
280
+ h = self.norm1(x, *norm_args)
281
+ h = self.act(h)
282
+ h = self.conv1(h)
283
+
284
+ h = self.norm2(h, *norm_args)
285
+ h = self.act(h)
286
+ h = self.dropout(h)
287
+ h = self.conv2(h)
288
+
289
+ if self.in_channels != self.out_channels:
290
+ if self.use_conv_shortcut:
291
+ x = self.conv_shortcut(x)
292
+ else:
293
+ x = self.nin_shortcut(x)
294
+
295
+ return x + h
296
+
297
+
298
+ class MoVQAttnBlock(nn.Module):
299
+
300
+ def __init__(
301
+ self,
302
+ in_channels: int,
303
+ zq_ch: Optional[int] = None,
304
+ add_conv: bool = False,
305
+ num_heads=1,
306
+ ):
307
+ super().__init__()
308
+ self.in_channels = in_channels
309
+ self.zq_ch = zq_ch
310
+ self.num_heads = num_heads
311
+
312
+ if zq_ch is None:
313
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
314
+ self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
315
+ else:
316
+ self.norm = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
317
+
318
+ self.q = nn.Conv2d(
319
+ in_channels,
320
+ in_channels,
321
+ kernel_size=1,
322
+ stride=1,
323
+ padding=0,
324
+ )
325
+ self.k = nn.Conv2d(
326
+ in_channels,
327
+ in_channels,
328
+ kernel_size=1,
329
+ stride=1,
330
+ padding=0,
331
+ )
332
+ self.v = nn.Conv2d(
333
+ in_channels,
334
+ in_channels,
335
+ kernel_size=1,
336
+ stride=1,
337
+ padding=0,
338
+ )
339
+ self.proj_out = nn.Conv2d(
340
+ in_channels,
341
+ in_channels,
342
+ kernel_size=1,
343
+ stride=1,
344
+ padding=0,
345
+ )
346
+
347
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
348
+ # x: [b, c1, h1, w1]
349
+ # zq: [b, c2, h2, w2]
350
+ # attention_mask: [b, 1, h3, w3]
351
+ norm_args = tuple() if self.zq_ch is None else (zq,)
352
+
353
+ # if context is not None:
354
+ # context = F.interpolate(context.float(), size=x.shape[-2:], mode="nearest").to(context.dtype)
355
+ # x = x + self.conv_context(context)
356
+
357
+ nx = self.norm(x, *norm_args)
358
+ q = self.q(nx)
359
+ k = self.k(nx)
360
+ v = self.v(nx)
361
+
362
+ b, c, h, w = q.shape
363
+ if is_xformers_available:
364
+ # If xformers is available, create attn_bias for xops.memory_efficient_attention.
365
+ attn_bias = None
366
+
367
+ v = xops.memory_efficient_attention(
368
+ rearrange(q, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
369
+ rearrange(k, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
370
+ rearrange(v, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
371
+ scale=1.0 / math.sqrt(c // self.num_heads),
372
+ attn_bias=attn_bias,
373
+ )
374
+ v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w).contiguous()
375
+ elif IS_SDPA_AVAILABLE:
376
+ # compute attention
377
+ q = rearrange(q, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
378
+ k = rearrange(k, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
379
+ v = rearrange(v, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
380
+
381
+ attn_bias = None
382
+
383
+ v = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
384
+ v = v.transpose(1, 2)
385
+ v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w)
386
+ else:
387
+ # compute attention
388
+ q = rearrange(q, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
389
+ k = rearrange(k, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
390
+ v = rearrange(v, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
391
+
392
+ # score = torch.bmm(q.permute(0, 2, 1), k)
393
+ score = torch.einsum('b n c k, b n c l -> b n k l', q, k)
394
+ score = score / math.sqrt(c // self.num_heads)
395
+
396
+ score = F.softmax(score, dim=-1)  # normalize over key positions (the l axis)
397
+
398
+ # attend to values
399
+ # v = v.reshape(b, c, h * w)
400
+ # v = torch.bmm(v, score.permute(0, 2, 1))
401
+ v = torch.einsum('b n c l, b n k l -> b n c k', v, score)
402
+ v = v.reshape(b, c, h, w)
403
+
404
+ v = self.proj_out(v)
405
+
406
+ return x + v
407
+
408
+
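The three branches above (xformers, SDPA, manual einsum) are intended to compute the same attention; a small float32 sketch checking manual softmax attention against SDPA:

    q = torch.randn(1, 1, 16, 8)     # [batch, heads, seq, head_dim]
    k = torch.randn(1, 1, 16, 8)
    v = torch.randn(1, 1, 16, 8)
    ref = F.scaled_dot_product_attention(q, k, v)
    attn = torch.softmax(q @ k.transpose(-2, -1) / 8 ** 0.5, dim=-1) @ v
    assert torch.allclose(ref, attn, atol=1e-5)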
409
+ class MoVQVectorQuantizer(nn.Module):
410
+
411
+ def __init__(self, config: MoVQConfig):
412
+ super().__init__()
413
+ self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
414
+ self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
415
+
416
+ def forward(self, x: torch.Tensor):
417
+ # b t c h w -> b t h w c
418
+ b, t, c, h, w = x.shape
419
+ x = x.permute(0, 1, 3, 4, 2).contiguous()
420
+ x_flattened = x.view(-1, c)
421
+
422
+ codebook = self.embedding.weight
423
+
424
+ d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \
425
+ torch.sum(codebook ** 2, dim=1) - 2 * \
426
+ torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0))
427
+
428
+ indices = torch.argmin(d, dim=1)
429
+ indices = indices.view(b, t, h, w)
430
+ return indices
431
+
432
+
433
+ class MoVQPretrainedModel(PreTrainedModel):
434
+ """
435
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
436
+ models.
437
+ """
438
+
439
+ config_class = MoVQConfig
440
+ base_model_prefix = "movq"
441
+ main_input_name = "pixel_values"
442
+ _no_split_modules = ["MoVQResnetBlock", "MoVQAttnBlock"]
443
+
444
+ def _init_weights(self, module):
445
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)):
446
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
447
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
448
+ elif isinstance(module, nn.Linear):
449
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
450
+ if module.bias is not None:
451
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
452
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
453
+ nn.init.uniform_(module.bias, -bound, bound)
454
+ elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
455
+ nn.init.constant_(module.weight, 1)
456
+ nn.init.constant_(module.bias, 0)
457
+
458
+
459
+ class MoVQEncoder(nn.Module):
460
+ def __init__(self, config: MoVQConfig):
461
+ super().__init__()
462
+ self.config = config
463
+ self.ch = config.ch
464
+ self.num_resolutions = len(config.ch_mult)
465
+ self.num_res_blocks = config.num_res_blocks
466
+ self.in_channels = config.in_channels
467
+
468
+ # downsampling
469
+ self.conv_in = nn.Conv2d(
470
+ self.in_channels,
471
+ self.ch,
472
+ kernel_size=3,
473
+ stride=1,
474
+ padding=1
475
+ )
476
+
477
+ in_ch_mult = (1,) + tuple(config.ch_mult)
478
+ self.down = nn.ModuleList()
479
+ for i_level in range(self.num_resolutions):
480
+ block = nn.ModuleList()
481
+ attn = nn.ModuleList()
482
+ block_in = config.ch * in_ch_mult[i_level]
483
+ block_out = config.ch * config.ch_mult[i_level]
484
+ for i_block in range(self.num_res_blocks):
485
+ block.append(
486
+ MoVQResnetBlock(
487
+ in_channels=block_in,
488
+ out_channels=block_out,
489
+ dropout=config.dropout,
490
+ )
491
+ )
492
+ block_in = block_out
493
+ if i_level in config.attn_resolutions:
494
+ attn.append(MoVQAttnBlock(block_in))
495
+
496
+ down = nn.Module()
497
+ down.block = block
498
+ down.attn = attn
499
+ if i_level != self.num_resolutions - 1:
500
+ if config.use_dc_up_down_blocks:
501
+ down.downsample = DCDownBlock2d(block_in)
502
+ else:
503
+ down.downsample = MoVQDownsample(block_in)
504
+
505
+ self.down.append(down)
506
+
507
+ # middle
508
+ self.mid = nn.Module()
509
+ self.mid.block_1 = MoVQResnetBlock(
510
+ in_channels=block_in,
511
+ out_channels=block_in,
512
+ dropout=config.dropout,
513
+ )
514
+ self.mid.attn_1 = MoVQAttnBlock(block_in)
515
+ self.mid.block_2 = MoVQResnetBlock(
516
+ in_channels=block_in,
517
+ out_channels=block_in,
518
+ dropout=config.dropout,
519
+ )
520
+
521
+ # end
522
+
523
+ self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
524
+
525
+ self.act = MoVQActivation()
526
+
527
+ out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels
528
+ self.conv_out = nn.Conv2d(
529
+ block_in,
530
+ out_z_channels,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1,
534
+ )
535
+
536
+ self.out_shortcut_average_group_size = block_in // out_z_channels
537
+
538
+ def forward(self, x: torch.Tensor):
539
+
540
+ # downsampling
541
+ h = self.conv_in(x)
542
+ for i_level in range(self.num_resolutions):
543
+ for i_block in range(self.num_res_blocks):
544
+ h = self.down[i_level].block[i_block](h)
545
+ if len(self.down[i_level].attn) > 0:
546
+ h = self.down[i_level].attn[i_block](h)
547
+
548
+ if i_level != self.num_resolutions - 1:
549
+ h = self.down[i_level].downsample(h)
550
+
551
+ h = self.mid.block_1(h)
552
+ h = self.mid.attn_1(h)
553
+ h = self.mid.block_2(h)
554
+
555
+ # end
556
+ h = self.norm_out(h)
557
+ h = self.act(h)
558
+
559
+ if self.config.use_dc_up_down_blocks:
560
+ x = h.unflatten(1, (-1, self.out_shortcut_average_group_size))
561
+ x = x.mean(dim=2)
562
+ h = self.conv_out(h) + x
563
+ else:
564
+ h = self.conv_out(h)
565
+ return h
566
+
567
+
568
+ class MoVQDecoder(nn.Module):
569
+ def __init__(self, config: MoVQConfig):
570
+ super().__init__()
571
+ self.config = config
572
+ self.ch = config.ch
573
+ self.num_resolutions = len(config.ch_mult)
574
+ self.num_res_blocks = config.num_res_blocks
575
+
576
+ in_ch_mult = (1,) + tuple(config.ch_mult)
577
+ zq_ch = config.embed_dim
578
+
579
+ block_in = config.ch * config.ch_mult[-1]
580
+
581
+ self.in_shortcut_repeats = block_in // config.embed_dim
582
+
583
+ self.conv_in = nn.Conv2d(
584
+ config.z_channels,
585
+ block_in,
586
+ kernel_size=3,
587
+ stride=1,
588
+ padding=1,
589
+ )
590
+
591
+ # middle
592
+ self.mid = nn.Module()
593
+ self.mid.block_1 = MoVQResnetBlock(
594
+ in_channels=block_in,
595
+ out_channels=block_in,
596
+ dropout=config.dropout,
597
+ zq_ch=zq_ch,
598
+ )
599
+ self.mid.attn_1 = MoVQAttnBlock(block_in, zq_ch)
600
+ self.mid.block_2 = MoVQResnetBlock(
601
+ in_channels=block_in,
602
+ out_channels=block_in,
603
+ dropout=config.dropout,
604
+ zq_ch=zq_ch,
605
+ )
606
+
607
+ # upsampling
608
+ self.up = nn.ModuleList()
609
+ for i_level in reversed(range(self.num_resolutions)):
610
+ block = nn.ModuleList()
611
+ attn = nn.ModuleList()
612
+ block_out = config.ch * config.ch_mult[i_level]
613
+ for i_block in range(self.num_res_blocks + 1):
614
+ block.append(
615
+ MoVQResnetBlock(
616
+ in_channels=block_in,
617
+ out_channels=block_out,
618
+ dropout=config.dropout,
619
+ zq_ch=zq_ch,
620
+ )
621
+ )
622
+ block_in = block_out
623
+ if i_level in config.attn_resolutions:
624
+ attn.append(MoVQAttnBlock(block_in, zq_ch))
625
+
626
+ up = nn.Module()
627
+ up.block = block
628
+ up.attn = attn
629
+ if i_level != 0:
630
+ if config.use_dc_up_down_blocks:
631
+ up.upsample = DCUpBlock2d(block_in)
632
+ else:
633
+ up.upsample = MoVQUpsample(block_in)
634
+
635
+ self.up.insert(0, up)
636
+
637
+ self.act = MoVQActivation()
638
+
639
+ self.norm_out = MoVQSpatialNorm(block_in, zq_ch)
640
+ self.conv_out = nn.Conv2d(
641
+ block_in,
642
+ config.out_channels,
643
+ kernel_size=3,
644
+ stride=1,
645
+ padding=1,
646
+ )
647
+
648
+ @property
649
+ def last_layer(self):
650
+ return self.conv_out.weight
651
+
652
+ def forward(self, z: torch.Tensor, zq: torch.Tensor):
653
+ h = z
654
+
655
+ if self.config.use_dc_up_down_blocks:
656
+ h = h.repeat_interleave(self.in_shortcut_repeats, dim=1)
657
+ h = self.conv_in(z) + h
658
+ else:
659
+ h = self.conv_in(h)
660
+
661
+ # middle
662
+ h = self.mid.block_1(h, zq)
663
+ h = self.mid.attn_1(h, zq)
664
+ h = self.mid.block_2(h, zq)
665
+
666
+ # upsampling
667
+ for i_level in reversed(range(self.num_resolutions)):
668
+ for i_block in range(self.num_res_blocks + 1):
669
+ h = self.up[i_level].block[i_block](h, zq)
670
+ if len(self.up[i_level].attn) > 0:
671
+ h = self.up[i_level].attn[i_block](h, zq)
672
+
673
+ if i_level != 0:
674
+ h = self.up[i_level].upsample(h)
675
+
676
+ h = self.norm_out(h, zq)
677
+ h = self.act(h)
678
+ h = self.conv_out(h)
679
+
680
+ return h
681
+
682
+
683
+ class Decoder(nn.Module):
684
+ def __init__(self, config: MoVQConfig):
685
+ super().__init__()
686
+ self.config = config
687
+ self.ch = config.ch
688
+ self.num_resolutions = len(config.ch_mult)
689
+ self.num_res_blocks = config.num_res_blocks
690
+
691
+ in_ch_mult = (1,) + tuple(config.ch_mult)
692
+
693
+ block_in = config.ch * config.ch_mult[-1]
694
+
695
+ self.conv_in = nn.Conv2d(
696
+ config.z_channels,
697
+ block_in,
698
+ kernel_size=3,
699
+ stride=1,
700
+ padding=1,
701
+ )
702
+
703
+ # middle
704
+ self.mid = nn.Module()
705
+ self.mid.block_1 = MoVQResnetBlock(
706
+ in_channels=block_in,
707
+ out_channels=block_in,
708
+ dropout=config.dropout,
709
+ )
710
+ self.mid.attn_1 = MoVQAttnBlock(block_in)
711
+ self.mid.block_2 = MoVQResnetBlock(
712
+ in_channels=block_in,
713
+ out_channels=block_in,
714
+ dropout=config.dropout,
715
+ )
716
+
717
+ # upsampling
718
+ self.up = nn.ModuleList()
719
+ for i_level in reversed(range(self.num_resolutions)):
720
+ block = nn.ModuleList()
721
+ attn = nn.ModuleList()
722
+ block_out = config.ch * config.ch_mult[i_level]
723
+ for i_block in range(self.num_res_blocks + 1):
724
+ block.append(
725
+ MoVQResnetBlock(
726
+ in_channels=block_in,
727
+ out_channels=block_out,
728
+ dropout=config.dropout,
729
+ )
730
+ )
731
+ block_in = block_out
732
+ if i_level in config.attn_resolutions:
733
+ attn.append(MoVQAttnBlock(block_in))
734
+
735
+ up = nn.Module()
736
+ up.block = block
737
+ up.attn = attn
738
+ if i_level != 0:
739
+ up.upsample = MoVQUpsample(block_in)
740
+
741
+ self.up.insert(0, up)
742
+
743
+ self.act = MoVQActivation()
744
+
745
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
746
+ self.norm_out = nn.GroupNorm(num_channels=block_in, **norm_kwargs)
747
+ self.conv_out = nn.Conv2d(
748
+ block_in,
749
+ config.out_channels,
750
+ kernel_size=3,
751
+ stride=1,
752
+ padding=1,
753
+ )
754
+
755
+ @property
756
+ def last_layer(self):
757
+ return self.conv_out.weight
758
+
759
+ def forward(self, z: torch.Tensor, zq: torch.Tensor):
760
+ h = z
761
+ h = self.conv_in(h)
762
+
763
+ # middle
764
+ h = self.mid.block_1(h)
765
+ h = self.mid.attn_1(h)
766
+ h = self.mid.block_2(h)
767
+
768
+ # upsampling
769
+ for i_level in reversed(range(self.num_resolutions)):
770
+ for i_block in range(self.num_res_blocks + 1):
771
+ h = self.up[i_level].block[i_block](h)
772
+ if len(self.up[i_level].attn) > 0:
773
+ h = self.up[i_level].attn[i_block](h)
774
+
775
+ if i_level != 0:
776
+ h = self.up[i_level].upsample(h)
777
+
778
+ h = self.norm_out(h)
779
+ h = self.act(h)
780
+ h = self.conv_out(h)
781
+
782
+ return h
783
+
784
+
785
+ class MoVQModel(MoVQPretrainedModel):
786
+
787
+ def __init__(self, config):
788
+ super().__init__(config)
789
+ self.config = config
790
+
791
+ self.encoder = MoVQEncoder(config)
792
+ self.decoder = MoVQDecoder(config)
793
+ self.quantize = MoVQVectorQuantizer(config)
794
+
795
+ self.quant_conv = nn.Conv2d(config.z_channels, config.embed_dim, 1)
796
+ self.post_quant_conv = nn.Conv2d(config.embed_dim, config.z_channels, 1)
797
+
798
+ self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1)
799
+
800
+ self.post_init()
801
+
802
+ def encode(self, x: torch.Tensor):
803
+ h = self.encoder(x)
804
+ h = self.quant_conv(h)
805
+ codes = self.quantize(h)
806
+ return codes
807
+
808
+ def decode(self, x: torch.Tensor):
809
+ b, h, w = x.shape  # assumes a [b, h, w] grid of code indices
810
+ quant = self.quantize.embedding(x.flatten())  # -> [b*h*w, c]; a 2-D tensor cannot unpack as (b, h, w, c)
811
+ quant = quant.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
812
+ quant2 = self.post_quant_conv(quant)
813
+ image = self.decoder(quant2, quant)
814
+ image = image.reshape(
815
+ b,
816
+ self.config.out_channels,
817
+ h * self.spatial_scale_factor,
818
+ w * self.spatial_scale_factor,
819
+ )
820
+ return image
821
+
822
+ @property
823
+ def device(self):
824
+ return next(self.parameters()).device
825
+
826
+ @property
827
+ def dtype(self):
828
+ return next(self.parameters()).dtype
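A decode-only sketch for the standalone MoVQ model above (config defaults assumed compatible). Note that encode() feeds the 4-D encoder output to a quantizer written for 5-D [b, t, c, h, w] inputs, so only decode is exercised here:

    config = MoVQConfig()                                      # defaults assumed
    movq = MoVQModel(config).eval()
    ids = torch.randint(0, config.codebook_size, (1, 16, 16))  # [b, h, w] token grid
    with torch.no_grad():
        img = movq.decode(ids)  # [1, out_channels, 16*s, 16*s], s = 2**(len(ch_mult)-1)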
modeling_qwen2vit.py ADDED
@@ -0,0 +1,842 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Qwen2-VL model."""
21
+
22
+ import math
23
+ from dataclasses import dataclass
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
25
+
26
+ from torch import Tensor
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+
33
+ from transformers.activations import ACT2FN
34
+ from transformers.cache_utils import Cache, StaticCache
35
+ from transformers.modeling_attn_mask_utils import (
36
+ AttentionMaskConverter,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPast,
40
+ ModelOutput,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.utils import (
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ is_torch_npu_available,
47
+ is_flash_attn_2_available,
48
+ is_flash_attn_greater_or_equal_2_10,
49
+ logging,
50
+ replace_return_docstrings,
51
+ )
52
+ from .configuration_qwen2vit import Qwen2VLConfig, Qwen2VLVisionConfig
53
+ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
54
+
55
+ from einops import rearrange
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CONFIG_FOR_DOC = "Qwen2VLConfig"
60
+
61
+ try:
62
+ import xformers.ops as xops
63
+
64
+ is_xformers_available = True
65
+ except Exception as e:
66
+ is_xformers_available = False
67
+
68
+ if is_flash_attn_2_available():
69
+ from flash_attn import flash_attn_varlen_func
70
+
71
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
72
+ else:
73
+ flash_attn_varlen_func = None
74
+
75
+
76
+ def init_weights(m):
77
+ if isinstance(m, nn.Linear):
78
+ # we use xavier_uniform following official JAX ViT:
79
+ torch.nn.init.xavier_uniform_(m.weight)
80
+ if m.bias is not None:
81
+ nn.init.constant_(m.bias, 0)
82
+ elif isinstance(m, nn.LayerNorm):
83
+ nn.init.constant_(m.bias, 0)
84
+ nn.init.constant_(m.weight, 1.0)
85
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
86
+ w = m.weight.data
87
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
88
+
89
+
90
+ @dataclass
91
+ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
92
+ """
93
+ Base class for Qwen2VL causal language model (or autoregressive) outputs.
94
+
95
+ Args:
96
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
97
+ Language modeling loss (for next-token prediction).
98
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
99
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
100
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
101
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
102
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
103
+
104
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
105
+ `past_key_values` input) to speed up sequential decoding.
106
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
107
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
108
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
109
+
110
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
111
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
112
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
113
+ sequence_length)`.
114
+
115
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
116
+ heads.
117
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
118
+ The rope index difference between sequence length and multimodal rope.
119
+ """
120
+
121
+ loss: Optional[torch.FloatTensor] = None
122
+ logits: torch.FloatTensor = None
123
+ past_key_values: Optional[List[torch.FloatTensor]] = None
124
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
125
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
126
+ rope_deltas: Optional[torch.LongTensor] = None
127
+
128
+
129
+ class Qwen2VLRotaryEmbedding(nn.Module):
130
+ def __init__(
131
+ self,
132
+ dim=None,
133
+ max_position_embeddings=2048,
134
+ base=10000,
135
+ device=None,
136
+ scaling_factor=1.0,
137
+ rope_type="default",
138
+ config: Optional[Qwen2VLConfig] = None,
139
+ ):
140
+ super().__init__()
141
+ # TODO (joao): remove the `if` below, only used for BC
142
+ self.rope_kwargs = {}
143
+ if config is None:
144
+ logger.warning_once(
145
+ "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
146
+ "`config` argument. All other arguments will be removed in v4.46"
147
+ )
148
+ self.rope_kwargs = {
149
+ "rope_type": rope_type,
150
+ "factor": scaling_factor,
151
+ "dim": dim,
152
+ "base": base,
153
+ "max_position_embeddings": max_position_embeddings,
154
+ }
155
+ self.rope_type = rope_type
156
+ self.max_seq_len_cached = max_position_embeddings
157
+ self.original_max_seq_len = max_position_embeddings
158
+ else:
159
+ # BC: "rope_type" was originally "type"
160
+ if config.rope_scaling is not None:
161
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
162
+ else:
163
+ self.rope_type = "default"
164
+ self.max_seq_len_cached = config.max_position_embeddings
165
+ self.original_max_seq_len = config.max_position_embeddings
166
+
167
+ self.config = config
168
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
169
+
170
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
171
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
172
+ self.original_inv_freq = self.inv_freq
173
+
174
+ def _dynamic_frequency_update(self, position_ids, device):
175
+ """
176
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
177
+ 1 - growing beyond the cached sequence length (allow scaling)
178
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179
+ """
180
+ seq_len = torch.max(position_ids) + 1
181
+ if seq_len > self.max_seq_len_cached: # growth
182
+ inv_freq, self.attention_scaling = self.rope_init_fn(
183
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
184
+ )
185
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
186
+ self.max_seq_len_cached = seq_len
187
+
188
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
189
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
190
+ self.max_seq_len_cached = self.original_max_seq_len
191
+
192
+ @torch.no_grad()
193
+ def forward(self, x, position_ids):
194
+ if "dynamic" in self.rope_type:
195
+ self._dynamic_frequency_update(position_ids, device=x.device)
196
+
197
+ # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
198
+ # So we expand the inv_freq to shape (3, ...)
199
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
200
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
201
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
202
+ device_type = x.device.type
203
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
204
+ with torch.autocast(device_type=device_type, enabled=False):
205
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
206
+ emb = torch.cat((freqs, freqs), dim=-1)
207
+ cos = emb.cos()
208
+ sin = emb.sin()
209
+
210
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
211
+ cos = cos * self.attention_scaling
212
+ sin = sin * self.attention_scaling
213
+
214
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
215
+
216
+
217
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
218
+ def rotate_half(x):
219
+ """Rotates half the hidden dims of the input."""
220
+ x1 = x[..., : x.shape[-1] // 2]
221
+ x2 = x[..., x.shape[-1] // 2:]
222
+ return torch.cat((-x2, x1), dim=-1)
223
+
224
+
225
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
226
+ orig_dtype = tensor.dtype
227
+ tensor = tensor.float()
228
+ cos = freqs.cos()
229
+ sin = freqs.sin()
230
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
231
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
232
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
233
+ output = output.to(orig_dtype)
234
+ return output
235
+
236
+
237
+ def apply_rotary_pos_emb_vision_batch(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
238
+ orig_dtype = tensor.dtype
239
+ tensor = tensor.float()
240
+ cos = freqs.cos()
241
+ sin = freqs.sin()
242
+ cos = cos.repeat(1, 1, 1, 2).float()
243
+ sin = sin.repeat(1, 1, 1, 2).float()
244
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
245
+ output = output.to(orig_dtype)
246
+ return output
247
+
248
+
249
+ class VisionRotaryEmbedding(nn.Module):
250
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
251
+ super().__init__()
252
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
253
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
254
+
255
+ def forward(self, seqlen: int, scale_factor: float = 1.0) -> torch.Tensor:
256
+ # Dynamically rescale inv_freq with scale_factor
257
+ scaled_inv_freq = self.inv_freq * scale_factor
258
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
259
+ freqs = torch.outer(seq, scaled_inv_freq)
260
+ return freqs
261
+
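Usage sketch (editorial; sizes are illustrative): the module returns the outer product of positions and inverse frequencies, one row per position, and `scale_factor` rescales every angle uniformly.

# Minimal sketch: the rotary table has shape (seqlen, dim // 2).
import torch

rope = VisionRotaryEmbedding(dim=8)
freqs = rope(seqlen=5)
assert freqs.shape == (5, 4)
assert torch.allclose(rope(5, scale_factor=0.5), 0.5 * freqs)  # angles scale linearly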
262
+
263
+ class PatchEmbed(nn.Module):
264
+ def __init__(
265
+ self,
266
+ patch_size: int = 14,
267
+ temporal_patch_size: int = 2,
268
+ in_channels: int = 3,
269
+ embed_dim: int = 1152,
270
+ ) -> None:
271
+ super().__init__()
272
+ self.patch_size = patch_size
273
+ self.temporal_patch_size = temporal_patch_size
274
+ self.in_channels = in_channels
275
+ self.embed_dim = embed_dim
276
+
277
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
278
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
279
+
280
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
281
+ target_dtype = self.proj.weight.dtype
282
+ if is_torch_npu_available():
283
284
+ hidden_states = F.linear(hidden_states, self.proj.weight.view(self.embed_dim, -1))
285
+ else:
286
+ hidden_states = hidden_states.view(
287
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
288
+ )
289
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
290
+ return hidden_states
291
+
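Shape arithmetic, for reference: each patch carries in_channels * temporal_patch_size * patch_size**2 values, and because the Conv3d stride equals its kernel size it acts as a per-patch linear projection, which is exactly why the NPU branch can substitute F.linear on the flattened weight. A hedged sketch with the default sizes:

# Minimal sketch: 3*2*14*14 = 1176 values per patch in, one 1152-dim embedding out.
import torch

pe = PatchEmbed()                                # defaults: 14, 2, 3, 1152
flat_patches = torch.randn(4, 3 * 2 * 14 * 14)   # 4 flattened patches
assert pe(flat_patches).shape == (4, 1152)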
292
+
293
+ class PatchMerger(nn.Module):
294
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
295
+ super().__init__()
296
+ self.hidden_size = context_dim * (spatial_merge_size ** 2)
297
+ self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
298
+ self.mlp = nn.Sequential(
299
+ nn.Linear(self.hidden_size, self.hidden_size),
300
+ nn.GELU(),
301
+ nn.Linear(self.hidden_size, dim),
302
+ )
303
+
304
+ def forward(self, x: torch.Tensor, grid_thw) -> torch.Tensor:
305
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
306
+ return x
307
+
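PatchMerger trades tokens for width: with spatial_merge_size=2, every 2x2 group of patch tokens is folded into one token of width context_dim * 4 before the MLP maps it to `dim`. A hedged sketch (sizes illustrative):

# Minimal sketch: 16 patch tokens -> 4 merged tokens.
import torch

merger = PatchMerger(dim=3584, context_dim=1152, spatial_merge_size=2)
tokens = torch.randn(16, 1152)            # token count must be a multiple of 4
assert merger(tokens, grid_thw=None).shape == (4, 3584)   # grid_thw is unused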
308
+
309
+ class VisionMlp(nn.Module):
310
+ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
311
+ super().__init__()
312
+ self.fc1 = nn.Linear(dim, hidden_dim)
313
+ self.act = ACT2FN[hidden_act]
314
+ self.fc2 = nn.Linear(hidden_dim, dim)
315
+
316
+ def forward(self, x) -> torch.Tensor:
317
+ return self.fc2(self.act(self.fc1(x)))
318
+
319
+
320
+ class VisionAttention(nn.Module):
321
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
322
+ super().__init__()
323
+ self.num_heads = num_heads
324
+ self.head_dim = dim // num_heads
325
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
326
+ self.proj = nn.Linear(dim, dim)
327
+
328
+ def forward(
329
+ self,
330
+ hidden_states: torch.Tensor,
331
+ cu_seqlens: torch.Tensor,
332
+ rotary_pos_emb: torch.Tensor = None
333
+ ) -> torch.Tensor:
334
+ seq_length = hidden_states.shape[0]
335
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
336
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
337
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
338
+
339
+ attention_mask = torch.full(
340
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
341
+ )
342
+ for i in range(1, len(cu_seqlens)):
343
+ attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
344
+
345
+ q = q.transpose(0, 1)
346
+ k = k.transpose(0, 1)
347
+ v = v.transpose(0, 1)
348
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
349
+ attn_weights = attn_weights + attention_mask
350
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
351
+ attn_output = torch.matmul(attn_weights, v)
352
+ attn_output = attn_output.transpose(0, 1)
353
+ attn_output = attn_output.reshape(seq_length, -1)
354
+ attn_output = self.proj(attn_output)
355
+ return attn_output
356
+
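The `cu_seqlens` convention used by all of these packed-attention variants is cumulative sequence lengths with a leading zero, so consecutive entries delimit the images/videos packed into one flat token stream; the loop above zeroes exactly the block-diagonal entries, keeping attention within each image. A small illustration:

# Minimal sketch: cu_seqlens = [0, 3, 5] packs a 3-token and a 2-token image.
import torch

cu_seqlens = torch.tensor([0, 3, 5])
mask = torch.full((1, 5, 5), float("-inf"))
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
# Tokens 0-2 attend only to 0-2; tokens 3-4 attend only to 3-4.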
357
+
358
+ class BatchVisionAttention(nn.Module):
359
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
360
+ super().__init__()
361
+ self.num_heads = num_heads
362
+ self.head_dim = dim // num_heads
363
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
364
+ self.proj = nn.Linear(dim, dim)
365
+
366
+ def forward(
367
+ self,
368
+ hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
369
+ attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
370
+ rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
371
+ ) -> torch.Tensor:
372
+ batch_size, seq_len, _ = hidden_states.shape
373
+
374
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
378
+ # [batch_size, num_heads, seq_len, head_dim]
379
+
380
+ if rotary_pos_emb is not None:
381
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
382
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
383
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
384
+
385
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
386
+ if attention_mask is not None:
387
+ attn_weights = attn_weights + attention_mask
388
+
389
+ # Softmax
390
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
391
+
392
+ attn_output = torch.matmul(attn_weights, v) # [batch_size, num_heads, seq_len, head_dim]
393
+ attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
394
+ return self.proj(attn_output)
395
+
396
+
397
+ class VisionXformerAttention(nn.Module):
398
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
399
+ super().__init__()
400
+ self.num_heads = num_heads
401
+ self.head_dim = dim // num_heads
402
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
403
+ self.proj = nn.Linear(dim, dim)
404
+
405
+ def forward(
406
+ self,
407
+ hidden_states: torch.Tensor,
408
+ cu_seqlens: torch.Tensor,
409
+ rotary_pos_emb: torch.Tensor = None
410
+ ) -> torch.Tensor:
411
+ seq_length = hidden_states.shape[0]
412
+
413
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
414
+
415
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
416
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
417
+
418
+ seqlens = [cu_seqlens[0]] + [cu_seqlens[i] - cu_seqlens[i - 1] for i in range(1, len(cu_seqlens))]
419
+ attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(seqlens)
420
+
421
+ attn_output = xops.memory_efficient_attention(
422
+ q, k, v.unsqueeze(0),
423
+ attn_bias=attn_bias,
424
+ scale=1.0 / math.sqrt(self.head_dim)
425
+ )
426
+ attn_output = attn_output.reshape(seq_length, -1)
427
+ attn_output = self.proj(attn_output)
428
+ return attn_output
429
+
430
+
431
+ class BatchVisionXformerAttention(nn.Module):
432
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
433
+ super().__init__()
434
+ self.num_heads = num_heads
435
+ self.head_dim = dim // num_heads
436
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
437
+ self.proj = nn.Linear(dim, dim)
438
+
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
443
+ rotary_pos_emb: torch.Tensor = None
444
+ ) -> torch.Tensor:
445
+ batch_size, seq_len, _ = hidden_states.shape
447
+
448
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
452
+ # [batch_size, num_heads, seq_len, head_dim]
453
+
454
+ if rotary_pos_emb is not None:
455
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
456
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
457
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
458
+
459
+ attn_output = xops.memory_efficient_attention(
460
+ q, k, v,
461
+ attn_bias=attention_mask,
462
+ scale=1.0 / math.sqrt(self.head_dim)
463
+ )
464
+ attn_output = attn_output.reshape(batch_size, seq_len, -1)
465
+ return self.proj(attn_output)
466
+
467
+
468
+ class VisionFlashAttention2(nn.Module):
469
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
470
+ super().__init__()
471
+ self.num_heads = num_heads
472
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
473
+ self.proj = nn.Linear(dim, dim)
474
+
475
+ def forward(
476
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
477
+ ) -> torch.Tensor:
478
+ seq_length = hidden_states.shape[0]
479
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
480
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
481
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
482
+
483
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
484
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
485
+ seq_length, -1
486
+ )
487
+ attn_output = self.proj(attn_output)
488
+ return attn_output
489
+
490
+
491
+ class BatchVisionFlashAttention2(nn.Module):
492
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
493
+ super().__init__()
494
+ self.num_heads = num_heads
495
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
496
+ self.proj = nn.Linear(dim, dim)
497
+
498
+ def forward(
499
+ self, hidden_states: torch.Tensor, attention_mask: torch.Tensor = None, rotary_pos_emb: torch.Tensor = None
500
+ ) -> torch.Tensor:
501
+ batch_size, seq_len, _ = hidden_states.shape
502
+
503
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
505
+
506
+ if rotary_pos_emb is not None:
507
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
508
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
509
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
510
+
511
+ q = rearrange(q, 'b h l d -> b l h d')
512
+ k = rearrange(k, 'b h l d -> b l h d')
513
+ v = rearrange(v, 'b h l d -> b l h d')
514
+
515
+ # `_flash_attention_forward` also takes the padding mask (2D, or None for full attention) and the query length.
+ attn_output = _flash_attention_forward(q, k, v, attention_mask, seq_len, is_causal=False).reshape(batch_size, seq_len, -1)
516
+ attn_output = self.proj(attn_output)
517
+ return attn_output
518
+
519
+
520
+ class VisionSdpaAttention(nn.Module):
521
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
522
+ super().__init__()
523
+ self.num_heads = num_heads
524
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
525
+ self.proj = nn.Linear(dim, dim)
526
+
527
+ def forward(
528
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
529
+ ) -> torch.Tensor:
530
+ seq_length = hidden_states.shape[0]
531
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
532
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
533
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
534
+
535
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
536
+ for i in range(1, len(cu_seqlens)):
537
+ attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
538
+ q = q.transpose(0, 1)
539
+ k = k.transpose(0, 1)
540
+ v = v.transpose(0, 1)
541
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
542
+ attn_output = attn_output.transpose(0, 1)
543
+ attn_output = attn_output.reshape(seq_length, -1)
544
+ attn_output = self.proj(attn_output)
545
+ return attn_output
546
+
547
+
548
+ class BatchVisionSdpaAttention(nn.Module):
549
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
550
+ super().__init__()
551
+ self.num_heads = num_heads
552
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
553
+ self.proj = nn.Linear(dim, dim)
554
+
555
+ def forward(
556
+ self,
557
+ hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
558
+ attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
559
+ rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
560
+ ) -> torch.Tensor:
561
+ batch_size, seq_len, _ = hidden_states.shape
562
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
564
+ # [batch_size, num_heads, seq_len, head_dim]
565
+
566
+ if rotary_pos_emb is not None:
567
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
568
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
569
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
570
+
571
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
572
+ attn_output = attn_output.transpose(1, 2)
573
+ attn_output = attn_output.reshape(batch_size, seq_len, -1)
574
+ attn_output = self.proj(attn_output)
575
+ return attn_output
576
+
577
+
578
+ QWEN2_VL_VISION_ATTENTION_CLASSES = {
579
+ "eager": VisionAttention,
580
+ "flash_attention_2": VisionFlashAttention2,
581
+ "sdpa": VisionSdpaAttention,
582
+ "xformers": VisionXformerAttention,
583
+ }
584
+
585
+ QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES = {
586
+ "eager": BatchVisionAttention,
587
+ "flash_attention_2": VisionFlashAttention2,
588
+ "sdpa": BatchVisionSdpaAttention,
589
+ }
590
+
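These registries follow the usual transformers pattern of dispatching on the configured attention implementation; a hedged usage sketch (sizes illustrative):

# Minimal sketch: resolve an attention backend by name, as the blocks below do.
attn_cls = QWEN2_VL_VISION_ATTENTION_CLASSES["sdpa"]
attn = attn_cls(dim=1152, num_heads=16)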
591
+
592
+ class Qwen2VLVisionBlock(nn.Module):
593
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
594
+ super().__init__()
595
+
596
+ self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
597
+ self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
598
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
599
+
600
+ self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
601
+ config.embed_dim, num_heads=config.num_heads,
602
+ )
603
+ self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
604
+
605
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb, grid_thw) -> torch.Tensor:
606
+ hidden_states = hidden_states + self.attn(
607
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
608
+ )
609
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
610
+ return hidden_states
611
+
612
+
613
+ class Qwen2VLBatchVisionBlock(nn.Module):
614
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
615
+ super().__init__()
616
+ self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
617
+ self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
618
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
619
+
620
+ self.attn = QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES[attn_implementation](
621
+ config.embed_dim, num_heads=config.num_heads,
622
+ )
623
+ self.mlp = VisionMlp(config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
624
+
625
+ def forward(
626
+ self,
627
+ hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
628
+ attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
629
+ rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
630
+ ) -> torch.Tensor:
631
+ # Attention
632
+ hidden_states = hidden_states + self.attn(
633
+ self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
634
+ )
635
+ # MLP
636
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
637
+ return hidden_states
638
+
639
+
640
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
641
+ def _prepare_4d_causal_attention_mask_with_cache_position(
642
+ attention_mask: torch.Tensor,
643
+ sequence_length: int,
644
+ target_length: int,
645
+ dtype: torch.dtype,
646
+ device: torch.device,
647
+ min_dtype: float,
648
+ cache_position: torch.Tensor,
649
+ batch_size: int,
650
+ ):
651
+ """
652
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
653
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
654
+
655
+ Args:
656
+ attention_mask (`torch.Tensor`):
657
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
658
+ sequence_length (`int`):
659
+ The sequence length being processed.
660
+ target_length (`int`):
661
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
662
+ dtype (`torch.dtype`):
663
+ The dtype to use for the 4D attention mask.
664
+ device (`torch.device`):
665
+ The device to place the 4D attention mask on.
666
+ min_dtype (`float`):
667
+ The minimum value representable with the dtype `dtype`.
668
+ cache_position (`torch.Tensor`):
669
+ Indices depicting the position of the input sequence tokens in the sequence.
670
+ batch_size (`int`):
671
+ Batch size.
672
+ """
673
+ if attention_mask is not None and attention_mask.dim() == 4:
674
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
675
+ causal_mask = attention_mask
676
+ else:
677
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
678
+ if sequence_length != 1:
679
+ causal_mask = torch.triu(causal_mask, diagonal=1)
680
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
681
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
682
+ if attention_mask is not None:
683
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
684
+ mask_length = attention_mask.shape[-1]
685
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
686
+ padding_mask = padding_mask == 0
687
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
688
+ padding_mask, min_dtype
689
+ )
690
+
691
+ return causal_mask
692
+
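A toy invocation (editorial sketch): with no padding, the helper reduces to the familiar causal mask, min_dtype above the diagonal and 0 on or below it, broadcast to (batch, 1, q_len, kv_len).

# Minimal sketch: one sequence of 4 tokens, no padding.
import torch

mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 4),
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(torch.float32).min,
    cache_position=torch.arange(4),
    batch_size=1,
)
assert mask.shape == (1, 1, 4, 4)
assert mask[0, 0, 0, 1] < 0 and mask[0, 0, 3, 0] == 0   # future masked, past visible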
693
+
694
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
695
+ class Qwen2RMSNorm(nn.Module):
696
+ def __init__(self, hidden_size, eps=1e-6):
697
+ """
698
+ Qwen2RMSNorm is equivalent to T5LayerNorm
699
+ """
700
+ super().__init__()
701
+ self.weight = nn.Parameter(torch.ones(hidden_size))
702
+ self.variance_epsilon = eps
703
+
704
+ def forward(self, hidden_states):
705
+ input_dtype = hidden_states.dtype
706
+ hidden_states = hidden_states.to(torch.float32)
707
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
708
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
709
+ return self.weight * hidden_states.to(input_dtype)
710
+
711
+ def extra_repr(self):
712
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
713
+
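Numerically, this is y = w * x / sqrt(mean(x**2) + eps), with no mean subtraction and no bias, unlike LayerNorm. A quick check against the formula:

# Minimal sketch: Qwen2RMSNorm against its closed form (weight is ones at init).
import torch

norm = Qwen2RMSNorm(hidden_size=4, eps=1e-6)
x = torch.randn(2, 4)
expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), expected, atol=1e-6)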
714
+
715
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
716
+ class Qwen2MLP(nn.Module):
717
+ def __init__(self, config):
718
+ super().__init__()
719
+ self.hidden_size = config.hidden_size
720
+ self.intermediate_size = config.intermediate_size
721
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
722
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
723
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
724
+ self.act_fn = ACT2FN[config.hidden_act]
725
+
726
+ def forward(self, hidden_state):
727
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
728
+
729
+
730
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
731
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
732
+ """
733
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
734
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
735
+ """
736
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
737
+ if n_rep == 1:
738
+ return hidden_states
739
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
740
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
741
+
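A shape-only sketch of `repeat_kv` for grouped-query attention: each key/value head is duplicated n_rep times along the head axis so it can serve n_rep query heads.

# Minimal sketch: 2 KV heads expanded for 8 query heads (n_rep = 4).
import torch

kv = torch.randn(1, 2, 5, 64)             # (batch, kv_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=4)
assert out.shape == (1, 8, 5, 64)
assert torch.equal(out[0, 0], out[0, 3])  # heads 0-3 all copy KV head 0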
742
+
743
+ class Qwen2VLPreTrainedModel(PreTrainedModel):
744
+ config_class = Qwen2VLConfig
745
+ base_model_prefix = "model"
746
+ supports_gradient_checkpointing = True
747
+ _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
748
+ _skip_keys_device_placement = "past_key_values"
749
+ _supports_flash_attn_2 = True
750
+ _supports_sdpa = True
751
+ _supports_cache_class = True
752
+ _supports_static_cache = True
753
+
754
+ def _init_weights(self, module):
755
+ std = self.config.initializer_range
756
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
757
+ module.weight.data.normal_(mean=0.0, std=std)
758
+ if module.bias is not None:
759
+ module.bias.data.zero_()
760
+ elif isinstance(module, nn.Embedding):
761
+ module.weight.data.normal_(mean=0.0, std=std)
762
+ if module.padding_idx is not None:
763
+ module.weight.data[module.padding_idx].zero_()
764
+
765
+
766
+ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
767
+ config_class = Qwen2VLVisionConfig
768
+ _no_split_modules = ["Qwen2VLVisionBlock"]
769
+
770
+ def __init__(self, config) -> None:
771
+ super().__init__(config)
772
+ self.spatial_merge_size = config.spatial_merge_size
773
+
774
+ self.patch_embed = PatchEmbed(
775
+ patch_size=config.patch_size,
776
+ temporal_patch_size=config.temporal_patch_size,
777
+ in_channels=config.in_channels,
778
+ embed_dim=config.embed_dim,
779
+ )
780
+
781
+ head_dim = config.embed_dim // config.num_heads
782
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
783
+
784
+ self.blocks = nn.ModuleList(
785
+ [Qwen2VLVisionBlock(config, config.attn_implementation) for _ in range(config.depth)]
786
+ )
787
+ self.merger = PatchMerger(
788
+ dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
789
+ )
790
+
791
+ def get_dtype(self) -> torch.dtype:
792
+ return self.blocks[0].mlp.fc2.weight.dtype
793
+
794
+ def get_device(self) -> torch.device:
795
+ return self.blocks[0].mlp.fc2.weight.device
796
+
797
+ def rot_pos_emb(self, grid_thw):
798
+ pos_ids = []
799
+ for t, h, w in grid_thw:
800
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
801
+ hpos_ids = hpos_ids.reshape(
802
+ h // self.spatial_merge_size,
803
+ self.spatial_merge_size,
804
+ w // self.spatial_merge_size,
805
+ self.spatial_merge_size,
806
+ )
807
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
808
+ hpos_ids = hpos_ids.flatten()
809
+
810
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
811
+ wpos_ids = wpos_ids.reshape(
812
+ h // self.spatial_merge_size,
813
+ self.spatial_merge_size,
814
+ w // self.spatial_merge_size,
815
+ self.spatial_merge_size,
816
+ )
817
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
818
+ wpos_ids = wpos_ids.flatten()
819
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
820
+ pos_ids = torch.cat(pos_ids, dim=0)
821
+ max_grid_size = grid_thw[:, 1:].max()
822
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
823
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
824
+ return rotary_pos_emb
825
+
826
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor,
827
+ output_hidden_states=False, org_forward=False) -> torch.Tensor:
828
+ hidden_states = self.patch_embed(hidden_states)
829
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
830
+
831
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
832
+ dim=0, dtype=torch.int32
833
+ )
834
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
835
+
836
+ for blk in self.blocks:
837
+ hidden_states = blk(hidden_states,
838
+ cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb,
839
+ grid_thw=grid_thw)
840
+
841
+ hidden_states = self.merger(hidden_states, grid_thw)
842
+ return hidden_states
modeling_rope_utils.py ADDED
@@ -0,0 +1,561 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Optional, Tuple
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import is_torch_available, logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ if is_torch_available():
27
+ import torch
28
+
29
+
30
+ def _compute_default_rope_parameters(
31
+ config: Optional[PretrainedConfig] = None,
32
+ device: Optional["torch.device"] = None,
33
+ seq_len: Optional[int] = None,
34
+ **rope_kwargs,
35
+ ) -> Tuple["torch.Tensor", float]:
36
+ """
37
+ Computes the inverse frequencies according to the original RoPE implementation
38
+ Args:
39
+ config ([`~transformers.PretrainedConfig`]):
40
+ The model configuration.
41
+ device (`torch.device`):
42
+ The device to use for initialization of the inverse frequencies.
43
+ seq_len (`int`, *optional*):
44
+ The current sequence length. Unused for this type of RoPE.
45
+ rope_kwargs (`Dict`, *optional*):
46
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
47
+ Returns:
48
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
49
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
50
+ """
51
+ if config is not None and len(rope_kwargs) > 0:
52
+ raise ValueError(
53
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
54
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
55
+ )
56
+ if len(rope_kwargs) > 0:
57
+ base = rope_kwargs["base"]
58
+ dim = rope_kwargs["dim"]
59
+ elif config is not None:
60
+ base = config.rope_theta
61
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
62
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
63
+ dim = int(head_dim * partial_rotary_factor)
64
+
65
+ attention_factor = 1.0 # Unused in this type of RoPE
66
+
67
+ # Compute the inverse frequencies
68
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
69
+ return inv_freq, attention_factor
70
+
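The function can also be driven through the BC `rope_kwargs` path, without a config; a hedged sketch of the geometric spectrum it returns:

# Minimal sketch: inv_freq[i] = 1 / base**(2i / dim), decaying geometrically from 1.
inv_freq, scaling = _compute_default_rope_parameters(base=10000, dim=8)
assert inv_freq.shape == (4,) and inv_freq[0] == 1.0 and scaling == 1.0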
71
+
72
+ def _compute_linear_scaling_rope_parameters(
73
+ config: Optional[PretrainedConfig] = None,
74
+ device: Optional["torch.device"] = None,
75
+ seq_len: Optional[int] = None,
76
+ **rope_kwargs,
77
+ ) -> Tuple["torch.Tensor", float]:
78
+ """
79
+ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
80
+ Args:
81
+ config ([`~transformers.PretrainedConfig`]):
82
+ The model configuration.
83
+ device (`torch.device`):
84
+ The device to use for initialization of the inverse frequencies.
85
+ seq_len (`int`, *optional*):
86
+ The current sequence length. Unused for this type of RoPE.
87
+ rope_kwargs (`Dict`, *optional*):
88
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
89
+ Returns:
90
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
91
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
92
+ """
93
+ if config is not None and len(rope_kwargs) > 0:
94
+ raise ValueError(
95
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
96
+ f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
97
+ )
98
+ if len(rope_kwargs) > 0:
99
+ factor = rope_kwargs["factor"]
100
+ elif config is not None:
101
+ factor = config.rope_scaling["factor"]
102
+
103
+ # Gets the default RoPE parameters
104
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
105
+
106
+ # Then applies linear scaling to the frequencies.
107
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
108
+ # applying scaling to the inverse frequencies is equivalent.
109
+ inv_freq /= factor
110
+ return inv_freq, attention_factor
111
+
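Making the NOTE above concrete: every angle is position * inv_freq, so dividing the frequencies by `factor` is identical to dividing the position ids by `factor`. A quick check:

# Minimal sketch: scaling frequencies == scaling positions.
import torch

inv_freq, _ = _compute_default_rope_parameters(base=10000, dim=8)
pos = torch.arange(16).float()
assert torch.allclose(torch.outer(pos, inv_freq / 4.0), torch.outer(pos / 4.0, inv_freq))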
112
+
113
+ def _compute_dynamic_ntk_parameters(
114
+ config: Optional[PretrainedConfig] = None,
115
+ device: Optional["torch.device"] = None,
116
+ seq_len: Optional[int] = None,
117
+ **rope_kwargs,
118
+ ) -> Tuple["torch.Tensor", float]:
119
+ """
120
+ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
121
+ Args:
122
+ config ([`~transformers.PretrainedConfig`]):
123
+ The model configuration.
124
+ device (`torch.device`):
125
+ The device to use for initialization of the inverse frequencies.
126
+ seq_len (`int`, *optional*):
127
+ The current sequence length, used to update the dynamic RoPE at inference time.
128
+ rope_kwargs (`Dict`, *optional*):
129
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
130
+ Returns:
131
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
132
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
133
+ """
134
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
135
+ if config is not None and len(rope_kwargs) > 0:
136
+ raise ValueError(
137
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
138
+ f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
139
+ )
140
+ if len(rope_kwargs) > 0:
141
+ base = rope_kwargs["base"]
142
+ dim = rope_kwargs["dim"]
143
+ max_position_embeddings = rope_kwargs["max_position_embeddings"]
144
+ factor = rope_kwargs["factor"]
145
+ elif config is not None:
146
+ base = config.rope_theta
147
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
148
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
149
+ dim = int(head_dim * partial_rotary_factor)
150
+ max_position_embeddings = config.max_position_embeddings
151
+ factor = config.rope_scaling["factor"]
152
+
153
+ attention_factor = 1.0 # Unused in this type of RoPE
154
+
155
+ # seq_len: default to max_position_embeddings, e.g. at init time
156
+ seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
157
+
158
+ # Compute the inverse frequencies
159
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
160
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
161
+ return inv_freq, attention_factor
162
+
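The one line doing real work above is the base update, base * (factor * seq_len / max_pos - (factor - 1)) ** (dim / (dim - 2)): at seq_len == max_position_embeddings the bracket equals 1 and nothing changes; beyond it the base grows, stretching the long wavelengths. A hedged sketch through the kwargs path:

# Minimal sketch: the base only moves once seq_len exceeds the training length.
import torch

kwargs = dict(base=10000, dim=8, max_position_embeddings=2048, factor=2.0)
inv_short, _ = _compute_dynamic_ntk_parameters(seq_len=1024, **kwargs)
inv_long, _ = _compute_dynamic_ntk_parameters(seq_len=4096, **kwargs)
inv_base, _ = _compute_default_rope_parameters(base=10000, dim=8)
assert torch.allclose(inv_short, inv_base)      # clamped: behaves like default RoPE
assert (inv_long[1:] < inv_base[1:]).all()      # lower freqs = longer wavelengths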
163
+
164
+ def _compute_yarn_parameters(
165
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
166
+ ) -> Tuple["torch.Tensor", float]:
167
+ """
168
+ Computes the inverse frequencies with YaRN scaling. Please refer to the
169
+ [original paper](https://arxiv.org/abs/2309.00071)
170
+ Args:
171
+ config ([`~transformers.PretrainedConfig`]):
172
+ The model configuration.
173
+ device (`torch.device`):
174
+ The device to use for initialization of the inverse frequencies.
175
+ seq_len (`int`, *optional*):
176
+ The current sequence length. Unused for this type of RoPE.
177
+ rope_kwargs (`Dict`, *optional*):
178
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
179
+ Returns:
180
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
181
+ post-processing scaling factor applied to the computed cos/sin.
182
+ """
183
+ # No need to keep BC with yarn, unreleased when this new pattern was created.
184
+ if len(rope_kwargs) > 0:
185
+ raise ValueError(
186
+ f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
187
+ )
188
+
189
+ base = config.rope_theta
190
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
191
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
192
+ dim = int(head_dim * partial_rotary_factor)
193
+ max_position_embeddings = config.max_position_embeddings
194
+ factor = config.rope_scaling["factor"]
195
+
196
+ # Sets the attention factor as suggested in the paper
197
+ attention_factor = config.rope_scaling.get("attention_factor")
198
+ if attention_factor is None:
199
+ attention_factor = 0.1 * math.log(factor) + 1.0
200
+
201
+ # Optional config options
202
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
203
+ beta_fast = config.rope_scaling.get("beta_fast") or 32
204
+ beta_slow = config.rope_scaling.get("beta_slow") or 1
205
+
206
+ # Compute the inverse frequencies
207
+ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
208
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
209
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
210
+
211
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
212
+ """Find dimension range bounds based on rotations"""
213
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
214
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
215
+ return max(low, 0), min(high, dim - 1)
216
+
217
+ def linear_ramp_factor(min, max, dim):
218
+ if min == max:
219
+ max += 0.001 # Prevent singularity
220
+
221
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
222
+ ramp_func = torch.clamp(linear_func, 0, 1)
223
+ return ramp_func
224
+
225
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
226
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
227
+ pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
228
+ inv_freq_extrapolation = 1.0 / pos_freqs
229
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
230
+
231
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
232
+
233
+ # Get n-dimensional rotational scaling corrected for extrapolation
234
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
235
+ inv_freq = (
236
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
237
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
238
+ )
239
+
240
+ return inv_freq, attention_factor
241
+
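When `attention_factor` is not set, the default follows the YaRN paper's mscale, 0.1 * ln(factor) + 1; e.g. an 8x extension scales cos/sin by about 1.208. A quick arithmetic check:

# Minimal sketch: the default YaRN attention (mscale) factor for factor = 8.
import math

assert abs((0.1 * math.log(8.0) + 1.0) - 1.2079) < 1e-3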
242
+
243
+ def _compute_longrope_parameters(
244
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
245
+ ) -> Tuple["torch.Tensor", float]:
246
+ """
247
+ Computes the inverse frequencies with LongRoPE scaling. Please refer to the
248
+ [original implementation](https://github.com/microsoft/LongRoPE)
249
+ Args:
250
+ config ([`~transformers.PretrainedConfig`]):
251
+ The model configuration.
252
+ device (`torch.device`):
253
+ The device to use for initialization of the inverse frequencies.
254
+ seq_len (`int`, *optional*):
255
+ The current sequence length. Unused for this type of RoPE.
256
+ rope_kwargs (`Dict`, *optional*):
257
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
258
+ Returns:
259
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
260
+ post-processing scaling factor applied to the computed cos/sin.
261
+ """
262
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
263
+ # No need to keep BC with longrope, unreleased when this new pattern was created.
264
+ if len(rope_kwargs) > 0:
265
+ raise ValueError(
266
+ "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
267
+ f"{rope_kwargs}"
268
+ )
269
+
270
+ base = config.rope_theta
271
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
272
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
273
+ dim = int(head_dim * partial_rotary_factor)
274
+ long_factor = config.rope_scaling["long_factor"]
275
+ short_factor = config.rope_scaling["short_factor"]
276
+ factor = config.rope_scaling.get("factor")
277
+ attention_factor = config.rope_scaling.get("attention_factor")
278
+
279
+ # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
280
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
281
+ # values to compute the default attention scaling factor, instead of using `factor`.
282
+ if hasattr(config, "original_max_position_embeddings"):
283
+ max_position_embeddings = config.original_max_position_embeddings
284
+ expanded_max_position_embeddings = config.max_position_embeddings
285
+ factor = expanded_max_position_embeddings / max_position_embeddings
286
+ else:
287
+ max_position_embeddings = config.max_position_embeddings
288
+ expanded_max_position_embeddings = max_position_embeddings * factor
289
+
290
+ # Sets the attention factor as suggested in the paper
291
+ if attention_factor is None:
292
+ if factor <= 1.0:
293
+ attention_factor = 1.0
294
+ else:
295
+ attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
296
+
297
+ # Compute the inverse frequencies -- scaled based on the target sequence length
298
+ if expanded_max_position_embeddings > max_position_embeddings:
299
+ ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
300
+ else:
301
+ ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
302
+ inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
303
+ inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
304
+
305
+ return inv_freq, attention_factor
306
+
307
+
308
+ def _compute_llama3_parameters(
309
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
310
+ ) -> Tuple["torch.Tensor", float]:
311
+ """
312
+ Computes the inverse frequencies for llama 3.1.
313
+
314
+ Args:
315
+ config ([`~transformers.PretrainedConfig`]):
316
+ The model configuration.
317
+ device (`torch.device`):
318
+ The device to use for initialization of the inverse frequencies.
319
+ seq_len (`int`, *optional*):
320
+ The current sequence length. Unused for this type of RoPE.
321
+ rope_kwargs (`Dict`, *optional*):
322
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
323
+ Returns:
324
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
325
+ post-processing scaling factor applied to the computed cos/sin.
326
+ """
327
+ # Gets the default RoPE parameters
328
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
329
+
330
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
331
+ low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
332
+ high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
333
+ old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
334
+
335
+ low_freq_wavelen = old_context_len / low_freq_factor
336
+ high_freq_wavelen = old_context_len / high_freq_factor
337
+
338
+ wavelen = 2 * math.pi / inv_freq
339
+ # wavelen < high_freq_wavelen: do nothing
340
+ # wavelen > low_freq_wavelen: divide by factor
341
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
342
+ # otherwise: interpolate between the two, using a smooth factor
343
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
344
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
345
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
346
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
347
+
348
+ return inv_freq_llama, attention_factor
349
+
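The llama3 rule splits frequencies into three wavelength bands relative to the old context length: short wavelengths (below old_context_len / high_freq_factor) pass through, long ones (above old_context_len / low_freq_factor) are divided by `factor`, and the band in between is blended via `smooth_factor`. A hedged check of the two extremes:

# Minimal sketch: the two extreme bands of the llama3 rule.
import math
import torch

inv_freq = torch.tensor([1.0, 1e-4])            # short and long wavelength
factor, low_ff, old_ctx = 8.0, 1.0, 8192
wavelen = 2 * math.pi / inv_freq
scaled = torch.where(wavelen > old_ctx / low_ff, inv_freq / factor, inv_freq)
assert scaled[0] == inv_freq[0]                 # wavelen ~6.3: untouched
assert scaled[1] == inv_freq[1] / factor        # wavelen ~62832: divided by 8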
350
+
351
+ # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
352
+ # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
353
+ # parameterizations, as long as the callable has the same signature.
354
+ ROPE_INIT_FUNCTIONS = {
355
+ "default": _compute_default_rope_parameters,
356
+ "linear": _compute_linear_scaling_rope_parameters,
357
+ "dynamic": _compute_dynamic_ntk_parameters,
358
+ "yarn": _compute_yarn_parameters,
359
+ "longrope": _compute_longrope_parameters,
360
+ "llama3": _compute_llama3_parameters,
361
+ }
362
+
363
+
364
+ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
365
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
366
+ # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
367
+ if "type" in received_keys:
368
+ received_keys -= {"type"}
369
+ required_keys.add("rope_type")
370
+
371
+ missing_keys = required_keys - received_keys
372
+ if missing_keys:
373
+ raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
374
+
375
+ if optional_keys is not None:
376
+ unused_keys = received_keys - required_keys - optional_keys
377
+ else:
378
+ unused_keys = received_keys - required_keys
379
+ if unused_keys:
380
+ logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
381
+
382
+
383
+ def _validate_default_rope_parameters(config: PretrainedConfig):
384
+ rope_scaling = config.rope_scaling
385
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
386
+ required_keys = {"rope_type"}
387
+ received_keys = set(rope_scaling.keys())
388
+ _check_received_keys(rope_type, received_keys, required_keys)
389
+
390
+
391
+ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
392
+ rope_scaling = config.rope_scaling
393
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
394
+ required_keys = {"rope_type", "factor"}
395
+ received_keys = set(rope_scaling.keys())
396
+ _check_received_keys(rope_type, received_keys, required_keys)
397
+
398
+ factor = rope_scaling["factor"]
399
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
400
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
401
+
402
+
403
+ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
404
+ rope_scaling = config.rope_scaling
405
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
406
+ required_keys = {"rope_type", "factor"}
407
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
408
+ optional_keys = {"original_max_position_embeddings"}
409
+ received_keys = set(rope_scaling.keys())
410
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
411
+
412
+ factor = rope_scaling["factor"]
413
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
414
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
415
+
416
+
417
+ def _validate_yarn_parameters(config: PretrainedConfig):
418
+ rope_scaling = config.rope_scaling
419
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
420
+ required_keys = {"rope_type", "factor"}
421
+ optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
422
+ received_keys = set(rope_scaling.keys())
423
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
424
+
425
+ factor = rope_scaling["factor"]
426
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
427
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
428
+
429
+ attention_factor = rope_scaling.get("attention_factor")
430
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
431
+ logger.warning(
432
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
433
+ )
434
+ beta_fast = rope_scaling.get("beta_fast")
435
+ if beta_fast is not None and not isinstance(beta_fast, float):
436
+ logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
437
+ beta_slow = rope_scaling.get("beta_slow")
438
+ if beta_slow is not None and not isinstance(beta_slow, float):
439
+ logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
440
+
441
+ if (beta_fast or 32) < (beta_slow or 1):
442
+ logger.warning(
443
+ f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
444
+ f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
445
+ )
446
+
447
+
448
+ def _validate_longrope_parameters(config: PretrainedConfig):
449
+ rope_scaling = config.rope_scaling
450
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
451
+ required_keys = {"rope_type", "short_factor", "long_factor"}
452
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
453
+ optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
454
+ received_keys = set(rope_scaling.keys())
455
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
456
+
457
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
458
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
459
+ dim = int(head_dim * partial_rotary_factor)
460
+
461
+ short_factor = rope_scaling.get("short_factor")
462
+ if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
463
+ logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
464
+ if not len(short_factor) == dim // 2:
465
+ logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
466
+
467
+ long_factor = rope_scaling.get("long_factor")
468
+ if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
469
+ logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
470
+ if not len(long_factor) == dim // 2:
471
+ logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
472
+
473
+ # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
474
+ # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
475
+ # unique to longrope (= undesirable)
476
+ if hasattr(config, "original_max_position_embeddings"):
477
+ logger.warning_once(
478
+ "This model has set a `original_max_position_embeddings` field, to be used together with "
479
+ "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
480
+ "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
481
+ "as it is compatible with most model architectures."
482
+ )
483
+ else:
484
+ factor = rope_scaling.get("factor")
485
+ if factor is None:
486
+ logger.warning("Missing required keys in `rope_scaling`: 'factor'")
487
+ elif not isinstance(factor, float) or factor < 1.0:
488
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
489
+
490
+ attention_factor = rope_scaling.get("attention_factor")
491
+ if attention_factor is not None:
492
+ if not isinstance(attention_factor, float) or attention_factor < 0.0:
493
+ logger.warning(
494
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
495
+ )
496
+
497
+
498
+ def _validate_llama3_parameters(config: PretrainedConfig):
499
+ rope_scaling = config.rope_scaling
500
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
501
+ required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
502
+ received_keys = set(rope_scaling.keys())
503
+ _check_received_keys(rope_type, received_keys, required_keys)
504
+
505
+ factor = rope_scaling["factor"]
506
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
507
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
508
+
509
+ low_freq_factor = rope_scaling["low_freq_factor"]
510
+ high_freq_factor = rope_scaling["high_freq_factor"]
511
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
512
+ logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
513
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
514
+ logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
515
+ if high_freq_factor <= low_freq_factor:
516
+ logger.warning(
517
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
518
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
519
+ )
520
+
521
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
522
+ if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
523
+ logger.warning(
524
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
525
+ f"{original_max_position_embeddings}"
526
+ )
527
+ if original_max_position_embeddings >= config.max_position_embeddings:
528
+ logger.warning(
529
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
530
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
531
+ )
532
+
533
+
534
+ # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
535
+ ROPE_VALIDATION_FUNCTIONS = {
536
+ "default": _validate_default_rope_parameters,
537
+ "linear": _validate_linear_scaling_rope_parameters,
538
+ "dynamic": _validate_dynamic_scaling_rope_parameters,
539
+ "yarn": _validate_yarn_parameters,
540
+ "longrope": _validate_longrope_parameters,
541
+ "llama3": _validate_llama3_parameters,
542
+ }
543
+
544
+
545
+ def rope_config_validation(config: PretrainedConfig):
546
+ """
547
+ Validate the RoPE config arguments, given a `PretrainedConfig` object
548
+ """
549
+ rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
550
+ if rope_scaling is None:
551
+ return
552
+
553
+ # BC: "rope_type" was originally "type"
554
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
555
+ validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
556
+ if validation_fn is not None:
557
+ validation_fn(config)
558
+ else:
559
+ logger.warning(
560
+ f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
561
+ )
preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_dualvitok.DualViTokImageProcessor"
4
+ },
5
+ "do_convert_rgb": true,
6
+ "do_normalize": true,
7
+ "do_rescale": true,
8
+ "do_resize": true,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_processor_type": "DualViTokImageProcessor",
15
+ "image_std": [
16
+ 0.5,
17
+ 0.5,
18
+ 0.5
19
+ ],
20
+ "max_pixels": 1048576,
21
+ "min_pixels": 1024,
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "max_pixels": 1048576,
26
+ "min_pixels": 1024
27
+ },
28
+ "spatial_factor": 16
29
+ }
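Because `auto_map` points at a custom `DualViTokImageProcessor`, loading this config goes through remote code; a hedged sketch (the repo id is a placeholder):

# Minimal sketch: load the custom image processor declared via auto_map.
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained(
    "org/this-repo",            # placeholder repo id
    trust_remote_code=True,     # required for DualViTokImageProcessor
)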
processing_qwen2vit.py ADDED
@@ -0,0 +1,186 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """
21
+ Processor class for Qwen2-VL.
22
+ """
23
+
24
+ from typing import List, Union
25
+
26
+
27
+ try:
28
+ from typing import Unpack
29
+ except ImportError:
30
+ from typing_extensions import Unpack
31
+
32
+
33
+ from transformers.feature_extraction_utils import BatchFeature
34
+ from .image_utils import ImageInput, VideoInput
35
+ from transformers.processing_utils import (
36
+ ProcessingKwargs,
37
+ ProcessorMixin,
38
+ )
39
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
40
+ from transformers.utils import logging
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False):
47
+ _defaults = {
48
+ "text_kwargs": {
49
+ "padding": False,
50
+ },
51
+ }
52
+
53
+
54
+ class Qwen2VLProcessor(ProcessorMixin):
55
+ r"""
56
+ Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.
57
+ [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
58
+ [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.
59
+ Args:
60
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
61
+ The image processor is a required input.
62
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
63
+ The tokenizer is a required input.
64
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
65
+ in a chat into a tokenizable string.
66
+ """
67
+
68
+ attributes = ["image_processor", "tokenizer"]
69
+ valid_kwargs = ["chat_template"]
70
+ image_processor_class = "Qwen2VLImageProcessor"
71
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
72
+
73
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
74
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
75
+
76
+ def __call__(
77
+ self,
78
+ images: ImageInput = None,
79
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
80
+ videos: VideoInput = None,
81
+ **kwargs: Unpack[Qwen2VLProcessorKwargs],
82
+ ) -> BatchFeature:
83
+ """
84
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
85
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
86
+ the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
87
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
88
+
89
+ Args:
90
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
91
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
92
+ tensor. Both channels-first and channels-last formats are supported.
93
+ text (`str`, `List[str]`, `List[List[str]]`):
94
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
95
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
96
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
97
+ videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
98
+ The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
99
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
100
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
101
+ If set, will return tensors of a particular framework. Acceptable values are:
102
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
103
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
104
+ - `'np'`: Return NumPy `np.ndarray` objects.
105
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
106
+
107
+ Returns:
108
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
109
+
110
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
111
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
112
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
113
+ `None`).
114
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
115
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
116
+ - **image_grid_thw** -- List of 3D image grids (temporal, height, width) fed to the LLM. Returned when `images` is not `None`.
117
+ - **video_grid_thw** -- List of 3D video grids (temporal, height, width) fed to the LLM. Returned when `videos` is not `None`.
118
+ """
119
+ output_kwargs = self._merge_kwargs(
120
+ Qwen2VLProcessorKwargs,
121
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
122
+ **kwargs,
123
+ )
124
+ if images is not None:
125
+ image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
126
+ image_grid_thw = image_inputs["image_grid_thw"]
127
+ else:
128
+ image_inputs = {}
129
+ image_grid_thw = None
130
+
131
+ if videos is not None:
132
+ videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
133
+ video_grid_thw = videos_inputs["video_grid_thw"]
134
+ else:
135
+ videos_inputs = {}
136
+ video_grid_thw = None
137
+
138
+ if not isinstance(text, list):
139
+ text = [text]
140
+
141
+ if image_grid_thw is not None:
142
+ merge_length = self.image_processor.merge_size**2
143
+ index = 0
144
+ for i in range(len(text)):
145
+ while "<|image_pad|>" in text[i]:
146
+ text[i] = text[i].replace(
147
+ "<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
148
+ )
149
+ index += 1
150
+ text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
151
+
152
+ if video_grid_thw is not None:
153
+ merge_length = self.image_processor.merge_size**2
154
+ index = 0
155
+ for i in range(len(text)):
156
+ while "<|video_pad|>" in text[i]:
157
+ text[i] = text[i].replace(
158
+ "<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
159
+ )
160
+ index += 1
161
+ text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
162
+
163
+ _ = output_kwargs["text_kwargs"].pop("padding_side", None)
164
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
165
+
166
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
167
+
168
+ def batch_decode(self, *args, **kwargs):
169
+ """
170
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
171
+ refer to the docstring of this method for more information.
172
+ """
173
+ return self.tokenizer.batch_decode(*args, **kwargs)
174
+
175
+ def decode(self, *args, **kwargs):
176
+ """
177
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
178
+ the docstring of this method for more information.
179
+ """
180
+ return self.tokenizer.decode(*args, **kwargs)
181
+
182
+ @property
183
+ def model_input_names(self):
184
+ tokenizer_input_names = self.tokenizer.model_input_names
185
+ image_processor_input_names = self.image_processor.model_input_names
186
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
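A hedged sketch of the placeholder-expansion arithmetic in `__call__` above (the grid values are illustrative; the real `image_grid_thw` comes from the image processor):

import torch

# One image whose vision grid is (t, h, w) = (1, 32, 32); with merge_size = 2
# the processor replaces each "<|image_pad|>" with t*h*w // merge_size**2
# = 1024 // 4 = 256 pad tokens before tokenization.
image_grid_thw = torch.tensor([[1, 32, 32]])
merge_length = 2 ** 2
num_pads = int(image_grid_thw[0].prod()) // merge_length  # 256
text = "Describe this image: <|image_pad|>"
expanded = text.replace("<|image_pad|>", "<|placeholder|>" * num_pads, 1)
expanded = expanded.replace("<|placeholder|>", "<|image_pad|>")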
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4da82e1c8624d84fc3120cc25001da7ad348fbcac426cc5c6b50375657c2a86
3
+ size 5401021086
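The three lines above are a Git LFS pointer, not the weights themselves; a hedged sketch of materializing the real file ("user/repo" is a placeholder for this repository's Hub id):

from huggingface_hub import hf_hub_download

# Resolves the LFS pointer to the actual ~5.4 GB weights file; the sha256
# above identifies the blob on the LFS server.
weights_path = hf_hub_download(repo_id="user/repo", filename="pytorch_model.bin")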
sdxl_decoder_pipe.py ADDED
@@ -0,0 +1,901 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
2
+ import inspect
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from einops import repeat, rearrange
12
+
13
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
14
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
15
+
16
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ import PIL.Image
21
+
22
+ from diffusers.models.attention_processor import (
23
+ AttnProcessor2_0,
24
+ FusedAttnProcessor2_0,
25
+ XFormersAttnProcessor,
26
+ )
27
+
28
+ from diffusers.utils import (
29
+ USE_PEFT_BACKEND,
30
+ deprecate,
31
+ is_invisible_watermark_available,
32
+ is_torch_xla_available,
33
+ logging,
34
+ replace_example_docstring,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+
39
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
40
+ from diffusers.loaders import (
41
+ FromSingleFileMixin,
42
+ IPAdapterMixin,
43
+ StableDiffusionXLLoraLoaderMixin,
44
+ TextualInversionLoaderMixin,
45
+ )
46
+
47
+ if is_invisible_watermark_available():
48
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
49
+
50
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
51
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline, \
52
+ retrieve_timesteps, rescale_noise_cfg
53
+
54
+ from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode
55
+
56
+ if is_torch_xla_available():
57
+ import torch_xla.core.xla_model as xm
58
+
59
+ XLA_AVAILABLE = True
60
+ else:
61
+ XLA_AVAILABLE = False
62
+
63
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
64
+
65
+
66
+ @dataclass
67
+ class StableDiffusionXLDecoderPipelineOutput(StableDiffusionXLPipelineOutput):
68
+ images: Union[List[PIL.Image.Image], np.ndarray]
69
+ indices_semantic: Optional[torch.Tensor] = None
70
+ indices_pixel: Optional[torch.Tensor] = None
71
+
72
+
73
+ def expand_dims_like(x, y):
74
+ while x.dim() != y.dim():
75
+ x = x.unsqueeze(-1)
76
+ return x
77
+
78
+
79
+ class AbstractEmbModel(nn.Module):
80
+ def __init__(self):
81
+ super().__init__()
82
+ self._is_trainable = None
83
+ self._ucg_rate = None
84
+ self._input_key = None
85
+
86
+ @property
87
+ def is_trainable(self) -> bool:
88
+ return self._is_trainable
89
+
90
+ @property
91
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
92
+ return self._ucg_rate
93
+
94
+ @property
95
+ def input_key(self) -> str:
96
+ return self._input_key
97
+
98
+ @is_trainable.setter
99
+ def is_trainable(self, value: bool):
100
+ self._is_trainable = value
101
+
102
+ @ucg_rate.setter
103
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
104
+ self._ucg_rate = value
105
+
106
+ @input_key.setter
107
+ def input_key(self, value: str):
108
+ self._input_key = value
109
+
110
+ @is_trainable.deleter
111
+ def is_trainable(self):
112
+ del self._is_trainable
113
+
114
+ @ucg_rate.deleter
115
+ def ucg_rate(self):
116
+ del self._ucg_rate
117
+
118
+ @input_key.deleter
119
+ def input_key(self):
120
+ del self._input_key
121
+
122
+
123
+ class DualViTok2ImageEmbedder(AbstractEmbModel):
124
+ def __init__(
125
+ self,
126
+ image_processor=None,
127
+ vq_model=None,
128
+ device="cuda",
129
+ dtype=torch.float32,
130
+ freeze=True,
131
+ image_size=0,
132
+ resize_factor=1,
133
+ not_bicubic=True,
134
+ return_sequence=False,
135
+ grid_feature_scale=1,
136
+ texture_drop_prob=0,
137
+ semantic_drop_prob=0,
138
+ pixel_channel=32,
139
+ semantic_channel=32,
140
+ ):
141
+ super().__init__()
142
+ vq_model.to(device=device, dtype=dtype)
143
+ vq_model.eval()
144
+
145
+ self.processor = image_processor
146
+
147
+ self.model = vq_model
148
+ self.device = device
149
+ if freeze:
150
+ self.freeze()
151
+
152
+ if image_size > 0:
153
+ preprocessor = [
154
+ Resize(image_size) if not_bicubic else Resize(image_size, interpolation=InterpolationMode.BICUBIC)]
155
+ preprocessor += [
156
+ CenterCrop(image_size),
157
+ ]
158
+ self.preprocessor = Compose(preprocessor)
159
+ self.image_size = image_size
160
+ self.resize_factor = resize_factor
161
+ self.not_bicubic = not_bicubic
162
+ self.return_sequence = return_sequence
163
+ self.grid_feature_scale = grid_feature_scale
164
+ self.texture_drop_prob = texture_drop_prob
165
+ self.semantic_drop_prob = semantic_drop_prob
166
+ self.pixel_channel = pixel_channel
167
+ self.semantic_channel = semantic_channel
168
+
169
+ def freeze(self):
170
+ self.model = self.model.eval()
171
+ for param in self.parameters():
172
+ param.requires_grad = False
173
+
174
+ def vq_encode(self, image):
175
+ if image.ndim == 5:
176
+ assert image.size(1) == 1
177
+ image = image.squeeze(1)
178
+ bs, _, h, w = image.shape
179
+
180
+ if self.image_size > 0:
181
+ image = self.preprocessor(image)
182
+ else:
183
+ assert self.resize_factor > 0
184
+ target_size = (int(h * self.resize_factor), int(w * self.resize_factor))
+ interpolation = InterpolationMode.BILINEAR if self.not_bicubic else InterpolationMode.BICUBIC
+ preprocessor = Resize(target_size, interpolation=interpolation)
187
+ image = preprocessor(image)
188
+
189
+ inputs = dict(image=image)
190
+ inputs = self.model.get_input(inputs)
191
+
192
+ (quant_semantic, diff_semantic, indices_semantic, target_semantic), \
193
+ (quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
194
+ return indices_semantic, indices_pixel
195
+
196
+ def vq_encode_code(self, image):
197
+ # `vq_encode` above already returns just the two index tensors.
+ indices_semantic, indices_pixel = self.vq_encode(image)
199
+ return indices_semantic, indices_pixel
200
+
201
+ def vq_decode_code(self, indices_semantic, indices_pixel):
202
+ return self.model.decode_code(indices_semantic, indices_pixel)
203
+
204
+ def forward(self, image, return_indices=False):
205
+ if image.ndim == 5:
206
+ assert image.size(1) == 1
207
+ image = image.squeeze(1)
208
+ bs, _, h, w = image.shape
209
+
210
+ if self.image_size > 0:
211
+ image = self.preprocessor(image)
212
+ else:
213
+ assert self.resize_factor > 0
214
+ target_size = (int(h * self.resize_factor), int(w * self.resize_factor))
+ interpolation = InterpolationMode.BILINEAR if self.not_bicubic else InterpolationMode.BICUBIC
+ preprocessor = Resize(target_size, interpolation=interpolation)
217
+ image = preprocessor(image)
218
+
219
+ inputs = dict(image=image)
220
+ inputs = self.model.get_input(inputs)
221
+
222
+ (quant_semantic, diff_semantic, indices_semantic, target_semantic), \
223
+ (quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
224
+
225
+ feature = self.model.merge_quants(quant_semantic, quant_pixel)
226
+
227
+ if self.return_sequence:
228
+ feature = rearrange(feature, 'b c h w -> b h w c')
229
+ _, this_h, this_w, _ = feature.shape
230
+ feature = feature.reshape(bs, this_h * this_w, -1)  # h*w tokens, not w*w; reshape since rearrange may return a non-contiguous tensor
231
+ else:
232
+ feature = feature * self.grid_feature_scale
233
+
234
+ if return_indices:
235
+ return feature, indices_semantic, indices_pixel
236
+
237
+ return feature
238
+
239
+ def encode(self, img):
240
+ return self(img)
241
+
242
+ def indices_to_codes(self, semantic_indices, texture_indices):
243
+ quant_semantic, quant_texture = self.model.indices_to_codes(semantic_indices, texture_indices)
244
+ return self.model.merge_quants(quant_semantic, quant_texture)
245
+
246
+
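+ # Note (illustrative, not in the original file): `merge_quants` appears to
+ # concatenate the semantic and pixel quantized maps along channels, which is
+ # why the pipeline below sets dualvitok_channels = pixel_channel +
+ # semantic_channel (32 + 32 = 64) and concatenates the result onto the
+ # 4-channel latent input of the UNet.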
247
+ class StableDiffusionXLDecoderPipeline(
248
+ DiffusionPipeline,
249
+ StableDiffusionMixin,
250
+ FromSingleFileMixin,
251
+ StableDiffusionXLLoraLoaderMixin,
252
+ TextualInversionLoaderMixin,
253
+ ):
254
+ model_cpu_offload_seq = "vq_model_embedder->unet->vae"
255
+ _optional_components = [
256
+ "vq_model_embedder",
257
+ ]
258
+ _callback_tensor_inputs = [
259
+ "latents",
260
+ "prompt_embeds",
261
+ "negative_prompt_embeds",
262
+ "add_text_embeds",
263
+ "add_time_ids",
264
+ "negative_pooled_prompt_embeds",
265
+ "negative_add_time_ids",
266
+ ]
267
+
268
+ def __init__(
269
+ self,
270
+ vae: AutoencoderKL,
271
+ unet: UNet2DConditionModel,
272
+ scheduler: KarrasDiffusionSchedulers,
273
+ force_zeros_for_empty_prompt: bool = True,
274
+ add_watermarker: Optional[bool] = None,
275
+ vq_image_processor=None,
276
+ vq_model=None,
277
+ ):
278
+ super().__init__()
279
+
280
+ self.register_modules(
281
+ vae=vae,
282
+ unet=unet,
283
+ scheduler=scheduler,
284
+ )
285
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
286
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
287
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
288
+
289
+ self.default_sample_size = self.unet.config.sample_size
290
+
291
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
292
+
293
+ if add_watermarker:
294
+ self.watermark = StableDiffusionXLWatermarker()
295
+ else:
296
+ self.watermark = None
297
+
298
+ self.empty_prompt_embeds = torch.zeros([1, 77, 2048]).to(device=unet.device, dtype=unet.dtype)
299
+ self.empty_pooled_prompt_embeds = torch.zeros([1, 1280]).to(device=unet.device, dtype=unet.dtype)
300
+ self.dualvitok_channels = vq_model.pixel_channel + vq_model.semantic_channel
301
+
302
+ self.resolution_group = ['(1024, 1024)', '(768, 1024)', '(1024, 768)', '(512, 2048)', '(2048, 512)',
303
+ '(640, 1920)', '(1920, 640)', '(768, 1536)', '(1536, 768)', '(768, 1152)',
304
+ '(1152, 768)', '(512, 512)']
305
+
306
+ embedder_kwargs = dict(image_size=0,
307
+ resize_factor=1,
308
+ return_sequence=False,
309
+ grid_feature_scale=1)
310
+ if isinstance(vq_model, DualViTok2ImageEmbedder):
311
+ self.vq_model_embedder = vq_model
312
+ else:
313
+ self.vq_model_embedder = DualViTok2ImageEmbedder(vq_image_processor, vq_model, **embedder_kwargs)
314
+
315
+ def vq_encode(self, image):
316
+ return self.vq_model_embedder.encode(image)
317
+
318
+ def vq_encode_code(self, image):
319
+ return self.vq_model_embedder.vq_encode_code(image)
320
+
321
+ def vq_decode_code(self, *args, **kwargs):
322
+ return self.vq_model_embedder.vq_decode_code(*args, **kwargs)
323
+
324
+ def indices_to_codes(self, *args, **kwargs):
325
+ return self.vq_model_embedder.indices_to_codes(*args, **kwargs)
326
+
327
+ def _get_add_time_ids(
328
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None,
329
+ resolution_index=None,
330
+ ):
331
+ add_time_ids = [resolution_index] * 6
332
+
333
+ passed_add_embed_dim = (
334
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
335
+ )
336
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
337
+
338
+ if expected_add_embed_dim != passed_add_embed_dim:
339
+ raise ValueError(
340
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
341
+ )
342
+
343
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
344
+ return add_time_ids
345
+
346
+ def check_inputs(
347
+ self,
348
+ height,
349
+ width,
350
+ callback_steps,
351
+ callback_on_step_end_tensor_inputs=None,
352
+ ):
353
+ if height % 8 != 0 or width % 8 != 0:
354
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
355
+
356
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
357
+ raise ValueError(
358
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
359
+ f" {type(callback_steps)}."
360
+ )
361
+
362
+ if callback_on_step_end_tensor_inputs is not None and not all(
363
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
364
+ ):
365
+ raise ValueError(
366
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
367
+ )
368
+
369
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
370
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
371
+ shape = (
372
+ batch_size,
373
+ num_channels_latents,
374
+ int(height) // self.vae_scale_factor,
375
+ int(width) // self.vae_scale_factor,
376
+ )
377
+ if isinstance(generator, list) and len(generator) != batch_size:
378
+ raise ValueError(
379
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
380
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
381
+ )
382
+
383
+ if latents is None:
384
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
385
+ else:
386
+ latents = latents.to(device)
387
+
388
+ # scale the initial noise by the standard deviation required by the scheduler
389
+ latents = latents * self.scheduler.init_noise_sigma
390
+ return latents
391
+
392
+ def upcast_vae(self):
393
+ dtype = self.vae.dtype
394
+ self.vae.to(dtype=torch.float32)
395
+ use_torch_2_0_or_xformers = isinstance(
396
+ self.vae.decoder.mid_block.attentions[0].processor,
397
+ (
398
+ AttnProcessor2_0,
399
+ XFormersAttnProcessor,
400
+ FusedAttnProcessor2_0,
401
+ ),
402
+ )
403
+ # if xformers or torch_2_0 is used attention block does not need
404
+ # to be in float32 which can save lots of memory
405
+ if use_torch_2_0_or_xformers:
406
+ self.vae.post_quant_conv.to(dtype)
407
+ self.vae.decoder.conv_in.to(dtype)
408
+ self.vae.decoder.mid_block.to(dtype)
409
+
410
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
411
+ def get_guidance_scale_embedding(
412
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
413
+ ) -> torch.Tensor:
414
+ """
415
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
416
+
417
+ Args:
418
+ w (`torch.Tensor`):
419
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
420
+ embedding_dim (`int`, *optional*, defaults to 512):
421
+ Dimension of the embeddings to generate.
422
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
423
+ Data type of the generated embeddings.
424
+
425
+ Returns:
426
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
427
+ """
428
+ assert len(w.shape) == 1
429
+ w = w * 1000.0
430
+
431
+ half_dim = embedding_dim // 2
432
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
433
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
434
+ emb = w.to(dtype)[:, None] * emb[None, :]
435
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
436
+ if embedding_dim % 2 == 1: # zero pad
437
+ emb = torch.nn.functional.pad(emb, (0, 1))
438
+ assert emb.shape == (w.shape[0], embedding_dim)
439
+ return emb
440
+
441
+ @property
442
+ def guidance_scale(self):
443
+ return self._guidance_scale
444
+
445
+ @property
446
+ def guidance_rescale(self):
447
+ return self._guidance_rescale
448
+
449
+ @property
450
+ def clip_skip(self):
451
+ return self._clip_skip
452
+
453
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
454
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
455
+ # corresponds to doing no classifier free guidance.
456
+ @property
457
+ def do_classifier_free_guidance(self):
458
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
459
+
460
+ @property
461
+ def cross_attention_kwargs(self):
462
+ return self._cross_attention_kwargs
463
+
464
+ @property
465
+ def denoising_end(self):
466
+ return self._denoising_end
467
+
468
+ @property
469
+ def num_timesteps(self):
470
+ return self._num_timesteps
471
+
472
+ @property
473
+ def interrupt(self):
474
+ return self._interrupt
475
+
476
+ def prepare_extra_step_kwargs(self, generator, eta):
477
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
478
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
479
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
480
+ # and should be between [0, 1]
481
+
482
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
483
+ extra_step_kwargs = {}
484
+ if accepts_eta:
485
+ extra_step_kwargs["eta"] = eta
486
+
487
+ # check if the scheduler accepts generator
488
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
489
+ if accepts_generator:
490
+ extra_step_kwargs["generator"] = generator
491
+ return extra_step_kwargs
492
+
493
+ @torch.no_grad()
494
+ def __call__(
495
+ self,
496
+ vq_indices: Optional[List] = None,
497
+ vq_embeds: Optional[torch.Tensor] = None,
498
+ images: Optional[PipelineImageInput] = None,
499
+ height: Optional[int] = None,
500
+ width: Optional[int] = None,
501
+ num_inference_steps: int = 50,
502
+ timesteps: List[int] = None,
503
+ sigmas: List[float] = None,
504
+ denoising_end: Optional[float] = None,
505
+ guidance_scale: float = 2.0,
506
+ eta: float = 0.0,
507
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
508
+ latents: Optional[torch.Tensor] = None,
509
+ output_type: Optional[str] = "pil",
510
+ return_dict: bool = True,
511
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
512
+ guidance_rescale: float = 0.0,
513
+ original_size: Optional[Tuple[int, int]] = None,
514
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
515
+ target_size: Optional[Tuple[int, int]] = None,
516
+ negative_original_size: Optional[Tuple[int, int]] = None,
517
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
518
+ negative_target_size: Optional[Tuple[int, int]] = None,
519
+ clip_skip: Optional[int] = None,
520
+ callback_on_step_end: Optional[
521
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
522
+ ] = None,
523
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
524
+ **kwargs,
525
+ ):
526
+ r"""
527
+ Function invoked when calling the pipeline for generation.
528
+
529
+ Args:
530
+ vq_indices (`Optional[PipelineImageInput]`, *optional*):
531
+ The VQ indices for semantic and pixel tokens. Should be a tuple of (semantic_indices, pixel_indices).
532
+ images (`Optional[PipelineImageInput]`, *optional*):
533
+ Input images in range [-1, 1] as torch.Tensor with shape (batch_size, channels, height, width).
534
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
535
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
536
+ Anything below 512 pixels won't work well for
537
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
538
+ and checkpoints that are not specifically fine-tuned on low resolutions.
539
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
540
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
541
+ Anything below 512 pixels won't work well for
542
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
543
+ and checkpoints that are not specifically fine-tuned on low resolutions.
544
+ num_inference_steps (`int`, *optional*, defaults to 50):
545
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
546
+ expense of slower inference.
547
+ timesteps (`List[int]`, *optional*):
548
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
549
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
550
+ passed will be used. Must be in descending order.
551
+ sigmas (`List[float]`, *optional*):
552
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
553
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
554
+ will be used.
555
+ denoising_end (`float`, *optional*):
556
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
557
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
558
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
559
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
560
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
561
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
562
+ guidance_scale (`float`, *optional*, defaults to 2.0):
563
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
564
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
565
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
566
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
567
+ usually at the expense of lower image quality.
568
+ eta (`float`, *optional*, defaults to 0.0):
569
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
570
+ [`schedulers.DDIMScheduler`], will be ignored for others.
571
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
572
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
573
+ to make generation deterministic.
574
+ latents (`torch.Tensor`, *optional*):
575
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
576
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
577
+ tensor will be generated by sampling using the supplied random `generator`.
578
+ output_type (`str`, *optional*, defaults to `"pil"`):
579
+ The output format of the generated image. Choose between
580
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
581
+ return_dict (`bool`, *optional*, defaults to `True`):
582
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
583
+ of a plain tuple.
584
+ cross_attention_kwargs (`dict`, *optional*):
585
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
586
+ `self.processor` in
587
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
588
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
589
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
590
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
591
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
592
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
593
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
594
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
595
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
596
+ explained in section 2.2 of
597
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
598
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
599
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
600
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
601
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
602
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
603
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
604
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
605
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
606
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
607
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
608
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
609
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
610
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
611
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
612
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
613
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
614
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
615
+ `._callback_tensor_inputs` attribute of your pipeline class.
616
+
617
+ Examples:
618
+
619
+ Returns:
620
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
621
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
622
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
623
+ """
624
+
625
+ callback = kwargs.pop("callback", None)
626
+ callback_steps = kwargs.pop("callback_steps", None)
627
+
628
+ if callback is not None:
629
+ deprecate(
630
+ "callback",
631
+ "1.0.0",
632
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
633
+ )
634
+ if callback_steps is not None:
635
+ deprecate(
636
+ "callback_steps",
637
+ "1.0.0",
638
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
639
+ )
640
+
641
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
642
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
643
+
644
+ # 0. Default height and width to unet
645
+ height = height or self.default_sample_size * self.vae_scale_factor
646
+ width = width or self.default_sample_size * self.vae_scale_factor
647
+
648
+ original_size = original_size or (height, width)
649
+ target_size = target_size or (height, width)
650
+
651
+ # 1. Check inputs. Raise error if not correct
652
+ self.check_inputs(
653
+ height,
654
+ width,
655
+ callback_steps,
656
+ callback_on_step_end_tensor_inputs,
657
+ )
658
+
659
+ self._guidance_scale = guidance_scale
660
+ self._guidance_rescale = guidance_rescale
661
+ self._clip_skip = clip_skip
662
+ self._cross_attention_kwargs = cross_attention_kwargs
663
+ self._denoising_end = denoising_end
664
+ self._interrupt = False
665
+
666
+ # 2. encode vq_embeds
667
+ assert images is not None or vq_indices is not None or vq_embeds is not None
668
+ batch_size = len(images) if images is not None else (len(vq_indices[0]) if vq_indices is not None else len(vq_embeds))
669
+
670
+ if images is not None:  # avoid ambiguous truth value when `images` is a tensor
671
+ vq_embeds, indices_semantic, indices_pixel = self.vq_model_embedder(images, return_indices=True)
672
+ elif vq_indices is not None:
673
+ indices_semantic, indices_pixel = vq_indices[0], vq_indices[1]
674
+ vq_embeds = self.vq_model_embedder.indices_to_codes(vq_indices[0], vq_indices[1])
675
+ elif vq_embeds is not None:  # tensor truthiness would raise for multi-element tensors
676
+ if isinstance(vq_embeds, list):
677
+ vq_embeds = self.vq_model_embedder.model.merge_quants(*vq_embeds)  # merge_quants lives on the wrapped vq model
678
+ indices_semantic, indices_pixel = None, None
679
+ else:
680
+ raise ValueError("No valid input provided")
681
+
682
+ device = self._execution_device
683
+
684
+ # 3. Encode input prompt
685
+ lora_scale = (
686
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
687
+ )
688
+
689
+ prompt_embeds = repeat(self.empty_prompt_embeds, '1 l c -> b l c', b=batch_size)
690
+ pooled_prompt_embeds = repeat(self.empty_pooled_prompt_embeds, '1 c -> b c', b=batch_size)
691
+
692
+ negative_prompt_embeds = prompt_embeds
693
+ negative_pooled_prompt_embeds = pooled_prompt_embeds
694
+
695
+ # 4. Prepare timesteps
696
+ timesteps, num_inference_steps = retrieve_timesteps(
697
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
698
+ )
699
+
700
+ # 5. Prepare latent variables
701
+ # num_channels_latents = self.unet.config.in_channels
702
+ num_channels_latents = 4
703
+ latents = self.prepare_latents(
704
+ batch_size,
705
+ num_channels_latents,
706
+ height,
707
+ width,
708
+ prompt_embeds.dtype,
709
+ device,
710
+ generator,
711
+ latents,
712
+ )
713
+
714
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
715
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
716
+
717
+ # 7. Prepare added time ids & embeddings
718
+ add_text_embeds = pooled_prompt_embeds
719
+ text_encoder_projection_dim = 1280
720
+
721
+ resolution = f'({width}, {height})'
722
+ assert resolution in self.resolution_group, f"resolution is not in the resolution group. Got {resolution}. Candidates: {self.resolution_group}"
723
+ resolution_index = self.resolution_group.index(resolution)
724
+ # resolution_index = None
725
+
726
+ add_time_ids = self._get_add_time_ids(
727
+ original_size,
728
+ crops_coords_top_left,
729
+ target_size,
730
+ dtype=prompt_embeds.dtype,
731
+ text_encoder_projection_dim=text_encoder_projection_dim,
732
+ resolution_index=resolution_index,
733
+ )
734
+ if negative_original_size is not None and negative_target_size is not None:
735
+ negative_add_time_ids = self._get_add_time_ids(
736
+ negative_original_size,
737
+ negative_crops_coords_top_left,
738
+ negative_target_size,
739
+ dtype=prompt_embeds.dtype,
740
+ text_encoder_projection_dim=text_encoder_projection_dim,
741
+ )
742
+ else:
743
+ negative_add_time_ids = add_time_ids
744
+
745
+ if self.do_classifier_free_guidance:
746
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
747
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
748
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
749
+
750
+ prompt_embeds = prompt_embeds.to(device)
751
+ add_text_embeds = add_text_embeds.to(device)
752
+ add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)
753
+
754
+ # 8. Denoising loop
755
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
756
+
757
+ # 8.1 Apply denoising_end
758
+ if (
759
+ self.denoising_end is not None
760
+ and isinstance(self.denoising_end, float)
761
+ and self.denoising_end > 0
762
+ and self.denoising_end < 1
763
+ ):
764
+ discrete_timestep_cutoff = int(
765
+ round(
766
+ self.scheduler.config.num_train_timesteps
767
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
768
+ )
769
+ )
770
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
771
+ timesteps = timesteps[:num_inference_steps]
772
+
773
+ # 9. Optionally get Guidance Scale Embedding
774
+ timestep_cond = None
775
+ if self.unet.config.time_cond_proj_dim is not None:
776
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size)
777
+ timestep_cond = self.get_guidance_scale_embedding(
778
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
779
+ ).to(device=device, dtype=latents.dtype)
780
+
781
+ self._num_timesteps = len(timesteps)
782
+ # with self.progress_bar(total=num_inference_steps) as progress_bar:
783
+ for i, t in enumerate(timesteps):
784
+ if self.interrupt:
785
+ continue
786
+
787
+ # expand the latents if we are doing classifier free guidance
788
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
789
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
790
+
791
+ if vq_embeds.size(-1) == latent_model_input.size(-1):
+ vq_embeds = vq_embeds.to(latent_model_input)
+ else:
+ vq_embeds = torch.nn.functional.interpolate(vq_embeds.to(latent_model_input),
+ size=latent_model_input.shape[-2:])
796
+ vq_embeds_input = torch.cat([torch.zeros_like(vq_embeds),
797
+ vq_embeds]) if self.do_classifier_free_guidance else vq_embeds
798
+
799
+ # predict the noise residual
800
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
801
+
802
+ latent_model_input = torch.cat([latent_model_input, vq_embeds_input], dim=1)
803
+ noise_pred = self.unet(
804
+ latent_model_input,
805
+ t,
806
+ encoder_hidden_states=prompt_embeds,
807
+ timestep_cond=timestep_cond,
808
+ cross_attention_kwargs=self.cross_attention_kwargs,
809
+ added_cond_kwargs=added_cond_kwargs,
810
+ return_dict=False,
811
+ )[0]
812
+
813
+ # perform guidance
814
+ if self.do_classifier_free_guidance:
815
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
816
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
817
+
818
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
819
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
820
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=self.guidance_rescale)
821
+
822
+ # compute the previous noisy sample x_t -> x_t-1
823
+ latents_dtype = latents.dtype
824
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
825
+ if latents.dtype != latents_dtype:
826
+ if torch.backends.mps.is_available():
827
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
828
+ latents = latents.to(latents_dtype)
829
+
830
+ if callback_on_step_end is not None:
831
+ callback_kwargs = {}
832
+ for k in callback_on_step_end_tensor_inputs:
833
+ callback_kwargs[k] = locals()[k]
834
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
835
+
836
+ latents = callback_outputs.pop("latents", latents)
837
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
838
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
839
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
840
+
841
+ # call the callback, if provided
842
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
843
+ # progress_bar.update()
844
+ if callback is not None and i % callback_steps == 0:
845
+ step_idx = i // getattr(self.scheduler, "order", 1)
846
+ callback(step_idx, t, latents)
847
+
848
+ if XLA_AVAILABLE:
849
+ xm.mark_step()
850
+
851
+ if not output_type == "latent":
852
+ # make sure the VAE is in float32 mode, as it overflows in float16
853
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
854
+
855
+ if needs_upcasting:
856
+ self.upcast_vae()
857
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
858
+ elif latents.dtype != self.vae.dtype:
859
+ if torch.backends.mps.is_available():
860
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
861
+ self.vae = self.vae.to(latents.dtype)
862
+
863
+ # unscale/denormalize the latents
864
+ # denormalize with the mean and std if available and not None
865
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
866
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
867
+ if has_latents_mean and has_latents_std:
868
+ latents_mean = (
869
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
870
+ )
871
+ latents_std = (
872
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
873
+ )
874
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
875
+ else:
876
+ latents = latents / self.vae.config.scaling_factor
877
+
878
+ image = self.vae.decode(latents, return_dict=False)[0]
879
+
880
+ # cast back to fp16 if needed
881
+ if needs_upcasting:
882
+ self.vae.to(dtype=torch.float16)
883
+ else:
884
+ image = latents
885
+
886
+ if not output_type == "latent":
887
+ # apply watermark if available
888
+ if self.watermark is not None:
889
+ image = self.watermark.apply_watermark(image)
890
+
891
+ image = self.image_processor.postprocess(image, output_type=output_type)
892
+
893
+ # Offload all models
894
+ self.maybe_free_model_hooks()
895
+
896
+ if not return_dict:
897
+ return (image,)
898
+
899
+ return StableDiffusionXLDecoderPipelineOutput(images=image,
900
+ indices_semantic=indices_semantic,
901
+ indices_pixel=indices_pixel)
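A hedged end-to-end sketch of driving this pipeline ("user/repo" is a placeholder Hub id; `indices_semantic`/`indices_pixel` are assumed to be precomputed, e.g. from `vq_encode_code` above or an autoregressive generator):

import torch
from diffusers import DiffusionPipeline

# Placeholder repo id; trust_remote_code lets diffusers import this module.
pipe = DiffusionPipeline.from_pretrained(
    "user/repo", trust_remote_code=True, torch_dtype=torch.float16
).to("cuda")

# height/width must format to an entry of pipe.resolution_group, e.g. '(1024, 1024)'.
# indices_semantic / indices_pixel: precomputed DualViTok token maps.
out = pipe(vq_indices=(indices_semantic, indices_pixel),
           height=1024, width=1024, num_inference_steps=50, guidance_scale=2.0)
image = out.images[0]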