WenshuoLi committed on
Commit
f4ff44d
·
1 Parent(s): 8e9ff09

upload transformers version

README.md CHANGED
@@ -1,4 +1,3 @@
1
-
2
  # openPangu-VL-7B
3
  Chinese | [English](README_EN.md) | [Technical Report](doc/technical_report.pdf)
4
 
@@ -69,9 +68,11 @@ openPangu-VL-7B 是基于昇腾 NPU ,基于openPangu-Embedded-7B-V1.1语言基
69
  | MBPP+ | 68.5 |
70
  | IFEval | 83.0 |
71
 
72
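- **Note:** The system prompt is empty. In general, setting the minimum image resolution to 2304\*28\*28 yields the best evaluation results. (The exception is OCR on the extremely small images in OCRBench, where a resolution of no more than 64\*28\*28 is recommended.) For the specific prompt and resolution settings, see the appendix of the [technical report](doc/technical_report.pdf).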
73
 
74
  ## 4. Deployment and Usage
 
 
75
  - Use the vllm-ascend inference framework and follow [[vllm_ascend_for_openpangu_vl_7b](doc/vllm_ascend_for_openpangu_vl_7b.md)] to deploy the service.
76
 
77
  - After the inference service is deployed, use this script to test whether the deployment succeeded.
@@ -79,6 +80,20 @@ openPangu-VL-7B 是基于昇腾 NPU ,基于openPangu-Embedded-7B-V1.1语言基
79
  cd inference/vllm_ascend/examples; python quick_start.py
80
  ```
81
 
82
  - For more inference examples and capability demonstrations, see `cookbooks`.
83
 
84
  ## 5. Model License
@@ -91,4 +106,4 @@ cd inference/vllm_ascend/examples; python quick_start.py
91
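  - The model's output does not constitute any advice or decision, and no guarantee is made as to the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in fields such as medicine and law in answering your questions. The generated content is for reference only and does not represent any attitude, position, or view of Huawei. You need to make independent judgments based on your actual situation, and Huawei assumes no responsibility.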
92
 
93
  ## 7. Feedback
94
- If you have any comments or suggestions, please submit an issue or contact [openPangu@huawei.com](url).
 
 
1
  # openPangu-VL-7B
2
  Chinese | [English](README_EN.md) | [Technical Report](doc/technical_report.pdf)
3
 
 
68
  | MBPP+ | 68.5 |
69
  | IFEval | 83.0 |
70
 
71
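+ **Note:** The evaluation uses **inference deployed with vllm-ascend and an empty system prompt**. In general, setting the minimum image resolution to 2304\*28\*28 yields the best evaluation results. (The exception is OCR on the extremely small images in OCRBench, where a resolution of no more than 64\*28\*28 is recommended.) For the specific prompt and resolution settings, see the appendix of the [technical report](doc/technical_report.pdf).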
72
 
73
  ## 4. Deployment and Usage
74
+
75
+ ### vllm-ascend deployment (recommended)
76
  - Use the vllm-ascend inference framework and follow [[vllm_ascend_for_openpangu_vl_7b](doc/vllm_ascend_for_openpangu_vl_7b.md)] to deploy the service.
77
 
78
  - After the inference service is deployed, use this script to test whether the deployment succeeded.
 
80
  cd inference/vllm_ascend/examples; python quick_start.py
81
  ```
82
 
83
+ ### Direct inference
84
+ Environment setup:
85
+ - python==3.10
86
+ - CANN==8.1.RC1
87
+ ```bash
88
+ cd inference; pip install -r requirements.txt
89
+ ```
90
+
91
+ Inference:
92
+ ```bash
93
+ cd inference; python generate.py
94
+ ```
95
+
96
+ ### Capability showcase
97
  - For more inference examples and capability demonstrations, see `cookbooks`.
98
 
99
  ## 5. Model License
 
106
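  - The model's output does not constitute any advice or decision, and no guarantee is made as to the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in fields such as medicine and law in answering your questions. The generated content is for reference only and does not represent any attitude, position, or view of Huawei. You need to make independent judgments based on your actual situation, and Huawei assumes no responsibility.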
107
 
108
  ## 7. Feedback
109
+ If you have any comments or suggestions, please submit an issue or contact [openPangu@huawei.com](url).
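
The note in the README gives image resolutions as products such as 2304\*28\*28 and 64\*28\*28. The sketch below simply expands those products into pixel counts. Reading the first factor as a count of 28x28 blocks is an assumption borrowed from Qwen2-VL-style processors, not something this commit states, and `pixel_budget` is a hypothetical helper name.

```python
import math

BLOCK_SIDE = 28  # side length of one block, as written in the README note


def pixel_budget(num_blocks: int, block_side: int = BLOCK_SIDE) -> int:
    """Expand an expression like 2304*28*28 into a total pixel count."""
    return num_blocks * block_side * block_side


# Values taken from the README note.
recommended_min_pixels = pixel_budget(2304)
tiny_ocr_max_pixels = pixel_budget(64)

for label, pixels in [
    ("recommended minimum", recommended_min_pixels),
    ("tiny-image OCR upper bound", tiny_ocr_max_pixels),
]:
    side = math.isqrt(pixels)  # side of a square image with this many pixels
    print(f"{label}: {pixels} pixels (about {side}x{side} if square)")
```

Under that reading, the recommended minimum corresponds to a square image of roughly 1344x1344 pixels, and the tiny-image bound to roughly 224x224.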
README_EN.md CHANGED
@@ -70,9 +70,11 @@ The openPangu-VL-7B is an efficient multimodal model based on the Ascend NPU, tr
70
  | MBPP+ | 68.5 |
71
  | IFEval | 83.0 |
72
 
73
- **Note:** The system prompt is empty. Generally, setting the minimum resolution to 2304\*28\*28 can yield the best evaluation results. (Except for the extremely small image OCR in OCRBench, it is recommended to set the resolution to no more than 64\*28\*28.) Detailed settings for different benchmarks can be found in [Technical Report](doc/technical_report.pdf).
74
 
75
  ## 4. Deployment
 
 
76
  - vllm-ascend: please refer to [[vllm_ascend_for_openpangu_vl_7b](doc/vllm_ascend_for_openpangu_vl_7b_EN.md)] to deploy the inference service.
77
 
78
  - After deployment finishes, you can test the API with the following script.
@@ -80,6 +82,21 @@ The openPangu-VL-7B is an efficient multimodal model based on the Ascend NPU, tr
80
  cd inference/vllm_ascend/examples; python quick_start.py
81
  ```
82
 
83
  - For more examples and demonstrations of model abilities, please refer to `cookbooks`.
84
 
85
  ## 5. Model License
 
70
  | MBPP+ | 68.5 |
71
  | IFEval | 83.0 |
72
 
73
+ **Note:** The evaluation is conducted with **vllm-ascend deployment** and **an empty system prompt**. Generally, setting the minimum resolution to 2304\*28\*28 yields the best evaluation results. (The exception is OCR on the extremely small images in OCRBench, where a resolution of no more than 64\*28\*28 is recommended.) Detailed settings for each benchmark can be found in the [Technical Report](doc/technical_report.pdf).
74
 
75
  ## 4. Deployment
76
+
77
+ ### vllm-ascend deployment (recommended)
78
  - vllm-ascend: please refer to [[vllm_ascend_for_openpangu_vl_7b](doc/vllm_ascend_for_openpangu_vl_7b_EN.md)] to deploy the inference service.
79
 
80
  - After deployment finishes, you can test the API with the following script.
 
82
  cd inference/vllm_ascend/examples; python quick_start.py
83
  ```
84
 
85
+ ### Direct inference
86
+
87
+ Environment:
88
+ - python==3.10
89
+ - CANN==8.1.RC1
90
+ ```bash
91
+ cd inference; pip install -r requirements.txt
92
+ ```
93
+
94
+ Inference:
95
+ ```bash
96
+ cd inference; python generate.py
97
+ ```
98
+
99
+ ### Model abilities
100
  - For more examples and demonstrations of model abilities, please refer to `cookbooks`.
101
 
102
  ## 5. Model License
generation_config.json CHANGED
@@ -5,6 +5,7 @@
5
  "eos_token_id": [
6
  45892
7
  ],
8
- "temperature": 0,
9
- "transformers_version": "4.52.4"
10
- }
 
 
5
  "eos_token_id": [
6
  45892
7
  ],
8
+ "temperature": 0.000001,
9
+ "top_k": 1,
10
+ "transformers_version": "4.53.2"
11
+ }
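
The `generation_config.json` change replaces `"temperature": 0` with a tiny positive temperature plus `"top_k": 1`. The commit does not say why. One plausible reading is that the pair keeps decoding effectively greedy while avoiding a zero temperature. The sketch below is a hand-written toy sampler, not the `transformers` implementation, and `sample_token` is a hypothetical name. It shows that with `top_k=1` only the highest-scoring token can ever be drawn, whatever the temperature.

```python
import math
import random


def sample_token(logits, temperature, top_k, rng):
    """Toy top-k sampling with temperature; returns the index of the chosen token."""
    if temperature <= 0:
        raise ValueError("temperature must be strictly positive in this sketch")
    # Keep only the indices of the top_k largest logits.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [logits[i] / temperature for i in ranked]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(ranked, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [0.1, 2.5, -1.0, 2.4]  # made-up scores for four tokens
picks = {sample_token(logits, 0.000001, 1, rng) for _ in range(100)}
print(picks)  # only the argmax index can appear
```

In this sketch a temperature of exactly 0 would divide by zero, which is why the function rejects it.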
inference/generate.py ADDED
@@ -0,0 +1,58 @@
1
+ from transformers import AutoProcessor
2
+ from transformers import AutoModelForCausalLM
3
+ from qwen_vl_utils import process_vision_info
4
+
5
+ model_path="../"
6
+
7
+ print(f"LOAD MODEL FROM: {model_path}")
8
+
9
+
10
+
11
+ key_mapping = {
12
+     "^visual": "model.visual",
13
+     r"^model(?!\.(language_model|visual))": "model.language_model",
14
+ }
15
+
16
+
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+     model_path,
19
+     trust_remote_code=True,
20
+     torch_dtype='auto',
21
+     key_mapping=key_mapping).eval().cuda()
22
+
23
+ conversation = [
24
+     {
25
+         "role": "system",
26
+         "content": [
27
+             {"type": "text", "text": "你是华为公司开发的多模态大模型,名字是openPangu-VL-7B。你能够处理文本和视觉模态的输入,并给出文本输出。"},
28
+         ]
29
+     },
30
+     {
31
+         "role": "user",
32
+         "content": [
33
+             {"type": "text", "text": "你好,你是谁?"},
34
+         ]
35
+     }
36
+ ]
37
+
38
+
39
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
40
+ text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
41
+
42
+ image_inputs, video_inputs = process_vision_info(conversation)
43
+
44
+ inputs = processor(
45
+     text=[text],
46
+     images=image_inputs,
47
+     videos=video_inputs,
48
+     padding=False,
49
+     return_tensors="pt",
50
+ )
51
+ inputs = inputs.to(model.device)
52
+
53
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
54
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
55
+ res = processor.batch_decode(
56
+     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
57
+ )
58
+ print(f"OUTPUT: {res}")
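
The `key_mapping` dictionary in `generate.py` maps regular expressions to replacement prefixes for checkpoint weight names. The sketch below applies the same two patterns with plain `re.subn` to show their effect. The first-match-wins loop and the sample key names are assumptions for illustration, and `rename_key` is a hypothetical helper. The actual renaming logic inside `from_pretrained` may differ in detail.

```python
import re

# The two patterns from generate.py.
key_mapping = {
    "^visual": "model.visual",
    r"^model(?!\.(language_model|visual))": "model.language_model",
}


def rename_key(key: str, mapping: dict) -> str:
    """Apply the first pattern that matches; leave the key unchanged otherwise."""
    for pattern, replacement in mapping.items():
        new_key, count = re.subn(pattern, replacement, key)
        if count:
            return new_key
    return key


# Hypothetical checkpoint key names, chosen only to exercise each branch.
for key in [
    "visual.blocks.0.attn.qkv.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.visual.blocks.0.attn.qkv.weight",
    "lm_head.weight",
]:
    print(f"{key} -> {rename_key(key, key_mapping)}")
```

The negative lookahead in the second pattern is what keeps keys that already start with `model.language_model` or `model.visual` from being prefixed a second time.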
inference/requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ torch==2.5.1
2
+ torch_npu==2.5.1
3
+ transformers==4.53.2
4
+ qwen_vl_utils==0.0.14
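
Because `requirements.txt` pins exact versions, it can be handy to compare the pins against what is installed before running `generate.py`. The following is a minimal sketch using only the standard library. `parse_pins` and `check_installed` are hypothetical helpers, and whether a distribution is registered under exactly the name written in the file (for example `torch_npu`) is an assumption to verify in your environment.

```python
from importlib import metadata

REQUIREMENTS = """\
torch==2.5.1
torch_npu==2.5.1
transformers==4.53.2
qwen_vl_utils==0.0.14
"""


def parse_pins(text: str) -> dict:
    """Parse 'name==version' lines into a {name: version} dictionary."""
    pins = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments and blank lines
        if not line:
            continue
        name, sep, version = line.partition("==")
        if not sep:
            raise ValueError(f"not an exact pin: {line!r}")
        pins[name.strip()] = version.strip()
    return pins


def check_installed(pins: dict) -> dict:
    """Return {name: (wanted, installed_or_None)} for each pinned package."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (wanted, installed)
    return report


for name, (wanted, installed) in check_installed(parse_pins(REQUIREMENTS)).items():
    status = "ok" if installed == wanted else "mismatch or missing"
    print(f"{name}: wanted {wanted}, found {installed} ({status})")
```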
inference/vllm_ascend/examples/quick_start.py CHANGED
@@ -86,4 +86,4 @@ def infer_message_with_api(prompt):
86
  return json.loads(response.text)["choices"][0]["message"]["content"]
87
 
88
  res = infer_message_with_api("你好,你是谁?")
89
- print(res)
 
86
  return json.loads(response.text)["choices"][0]["message"]["content"]
87
 
88
  res = infer_message_with_api("你好,你是谁?")
89
+ print(res)
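
`quick_start.py` reads the reply with `json.loads(response.text)["choices"][0]["message"]["content"]`. The sketch below wraps that same lookup in a small function with an explicit error when no choices are returned. `extract_content` is a hypothetical helper and the sample payload is made up for illustration. It is not output from a real server.

```python
import json


def extract_content(response_text: str) -> str:
    """Pull the assistant message text out of a chat-completion style JSON body."""
    payload = json.loads(response_text)
    choices = payload.get("choices") or []
    if not choices:
        raise ValueError("response contains no choices")
    return choices[0]["message"]["content"]


# Made-up response body with the same nesting that quick_start.py indexes into.
sample = json.dumps(
    {"choices": [{"message": {"role": "assistant", "content": "hello"}}]}
)
print(extract_content(sample))
```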
inference/vllm_ascend/examples/start_serving_openpangu_vl_7b.sh CHANGED
@@ -95,4 +95,4 @@ echo ${command} | sed "s/--/\n --/g"
95
  # echo ${command}
96
  ${command} | tee $OUTPUT_TEXT_DIR/inference.log
97
 
98
- wait
 
95
  # echo ${command}
96
  ${command} | tee $OUTPUT_TEXT_DIR/inference.log
97
 
98
+ wait
modeling_openpangu_embedded.py ADDED
@@ -0,0 +1,712 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ import torch_npu
28
+ from torch_npu.contrib import transfer_to_npu
29
+
30
+ if "910" in torch.npu.get_device_name():
31
+ NPU_ATTN_INFR = True
32
+ print("[INFO] torch_npu detected. Using NPU fused infer attention.")
33
+ else:
34
+ NPU_ATTN_INFR = False
35
+
36
+ from transformers.activations import ACT2FN
37
+ from transformers.cache_utils import Cache, DynamicCache
38
+ from transformers.generation import GenerationMixin
39
+ from transformers.masking_utils import create_causal_mask
40
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
41
+ from transformers.modeling_layers import GradientCheckpointingLayer
42
+ from transformers.modeling_outputs import (
43
+ BaseModelOutputWithPast,
44
+ CausalLMOutputWithPast,
45
+ SequenceClassifierOutputWithPast,
46
+ )
47
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
48
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
49
+ from transformers.processing_utils import Unpack
50
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
51
+ from transformers.configuration_utils import PretrainedConfig
52
+
53
+
54
+ class PanguEmbeddedConfig(PretrainedConfig):
55
+
56
+ model_type = "PanguEmbedded"
57
+ _auto_class = "AutoConfig"
58
+
59
+ def __init__(
60
+ self,
61
+ vocab_size=153376,
62
+ hidden_size=4096,
63
+ intermediate_size=12800,
64
+ num_hidden_layers=34,
65
+ num_attention_heads=32,
66
+ num_key_value_heads=8,
67
+ hidden_act="silu",
68
+ max_position_embeddings=32768,
69
+ initializer_range=0.02,
70
+ rms_norm_eps=1e-5,
71
+ use_cache=True,
72
+ pad_token_id=0,
73
+ bos_token_id=1,
74
+ eos_token_id=45892,
75
+ tie_word_embeddings=False,
76
+ rope_theta=16000000.0,
77
+ bias=True,
78
+ **kwargs,
79
+ ):
80
+ self.vocab_size = vocab_size
81
+ self.max_position_embeddings = max_position_embeddings
82
+ self.hidden_size = hidden_size
83
+ self.intermediate_size = intermediate_size
84
+ self.num_hidden_layers = num_hidden_layers
85
+ self.num_attention_heads = num_attention_heads
86
+ self.num_key_value_heads = num_key_value_heads
87
+ self.hidden_act = hidden_act
88
+ self.initializer_range = initializer_range
89
+ self.rms_norm_eps = rms_norm_eps
90
+ self.use_cache = use_cache
91
+ self.rope_theta = rope_theta
92
+ self.bias = bias
93
+ super().__init__(
94
+ pad_token_id=pad_token_id,
95
+ bos_token_id=bos_token_id,
96
+ eos_token_id=eos_token_id,
97
+ tie_word_embeddings=tie_word_embeddings,
98
+ **kwargs,
99
+ )
100
+
101
+
102
+ logger = logging.get_logger(__name__)
103
+
104
+
105
+ class PanguEmbeddedRMSNorm(nn.Module):
106
+ def __init__(self, hidden_size, eps=1e-6):
107
+ """
108
+ PanguEmbeddedRMSNorm is equivalent to T5LayerNorm
109
+ """
110
+ super().__init__()
111
+ self.weight = nn.Parameter(torch.ones(hidden_size))
112
+ self.variance_epsilon = eps
113
+
114
+ def forward(self, hidden_states):
115
+ input_dtype = hidden_states.dtype
116
+ hidden_states = hidden_states.to(torch.float32)
117
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
118
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
119
+ return self.weight * hidden_states.to(input_dtype)
120
+
121
+ def extra_repr(self):
122
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
123
+
124
+
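
To make the arithmetic in `PanguEmbeddedRMSNorm.forward` concrete, here is a dependency-free sketch of the same formula on a single vector: divide by the root mean square, then scale by the learned weight. It omits the float32 cast and tensor broadcasting of the original, and `rms_norm` is a hypothetical name.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """RMS-normalize one vector: weight * x / sqrt(mean(x^2) + eps)."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]


x = [3.0, 4.0]
out = rms_norm(x, weight=[1.0, 1.0], eps=0.0)
# With unit weights and eps=0 the output has a mean square of 1.
print(out, sum(v * v for v in out) / len(out))
```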
125
+ class PanguEmbeddedRotaryEmbedding(nn.Module):
126
+ def __init__(self, config: PanguEmbeddedConfig, device=None):
127
+ super().__init__()
128
+ # BC: "rope_type" was originally "type"
129
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
130
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
131
+ else:
132
+ self.rope_type = "default"
133
+ self.max_seq_len_cached = config.max_position_embeddings
134
+ self.original_max_seq_len = config.max_position_embeddings
135
+
136
+ self.config = config
137
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
138
+
139
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
140
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
141
+ self.original_inv_freq = self.inv_freq
142
+
143
+ @torch.no_grad()
144
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
145
+ def forward(self, x, position_ids):
146
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
147
+ position_ids_expanded = position_ids[:, None, :].float()
148
+
149
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
150
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
151
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
152
+ emb = torch.cat((freqs, freqs), dim=-1)
153
+ cos = emb.cos() * self.attention_scaling
154
+ sin = emb.sin() * self.attention_scaling
155
+
156
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
157
+
158
+
159
+ def rotate_half(x):
160
+ """Rotates half the hidden dims of the input."""
161
+ x1 = x[..., : x.shape[-1] // 2]
162
+ x2 = x[..., x.shape[-1] // 2 :]
163
+ return torch.cat((-x2, x1), dim=-1)
164
+
165
+
166
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
167
+ """Applies Rotary Position Embedding to the query and key tensors.
168
+
169
+ Args:
170
+ q (`torch.Tensor`): The query tensor.
171
+ k (`torch.Tensor`): The key tensor.
172
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
173
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
174
+ position_ids (`torch.Tensor`, *optional*):
175
+ Deprecated and unused.
176
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
177
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
178
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
179
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
180
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
181
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
182
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
183
+ Returns:
184
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
185
+ """
186
+ cos = cos.unsqueeze(unsqueeze_dim)
187
+ sin = sin.unsqueeze(unsqueeze_dim)
188
+ q_embed = (q * cos) + (rotate_half(q) * sin)
189
+ k_embed = (k * cos) + (rotate_half(k) * sin)
190
+ return q_embed, k_embed
191
+
192
+
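
The combination `q * cos + rotate_half(q) * sin` used above rotates each pair of channels `(i, i + head_dim/2)` by a position-dependent angle. The sketch below reproduces this for a single head vector in plain Python. The frequency formula assumes the default RoPE initialization with the config's `rope_theta=16000000.0`, and `rope_tables` and `apply_rope` are hypothetical helper names.

```python
import math


def rotate_half(x):
    """Mirror of the file's rotate_half for a flat list: (x1, x2) -> (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope_tables(position, head_dim, theta=16000000.0):
    """cos/sin tables for one position, assuming the default RoPE frequencies."""
    half = head_dim // 2
    inv_freq = [theta ** (-(2 * i) / head_dim) for i in range(half)]
    angles = [position * f for f in inv_freq]
    emb = angles + angles  # same duplication as torch.cat((freqs, freqs), dim=-1)
    return [math.cos(a) for a in emb], [math.sin(a) for a in emb]


def apply_rope(x, cos, sin):
    """x * cos + rotate_half(x) * sin, element by element."""
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


q = [1.0, 2.0, 3.0, 4.0]
cos, sin = rope_tables(position=7, head_dim=4)
q_rot = apply_rope(q, cos, sin)
# A rotation changes direction but not length.
print(sum(v * v for v in q), sum(v * v for v in q_rot))
```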
193
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
194
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
195
+ (https://qwenlm.github.io/blog/qwen2-vl/).
196
+
197
+ Explanation:
198
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
199
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
200
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
201
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
202
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
203
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
204
+ difference with modern LLMs.
205
+
206
+ Args:
207
+ q (`torch.Tensor`): The query tensor.
208
+ k (`torch.Tensor`): The key tensor.
209
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
210
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
211
+ position_ids (`torch.Tensor`):
212
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
213
+ used to pass offsetted position ids when working with a KV-cache.
214
+ mrope_section(`List(int)`):
215
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
216
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
217
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
218
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
219
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
220
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
221
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
222
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
223
+ Returns:
224
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
225
+ """
226
+ mrope_section = mrope_section * 2
227
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
228
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
229
+
230
+ q_embed = (q * cos) + (rotate_half(q) * sin)
231
+ k_embed = (k * cos) + (rotate_half(k) * sin)
232
+ return q_embed, k_embed
233
+
234
+
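
The `cos.split(...)` and `m[i % 3]` expression in `apply_multimodal_rotary_pos_emb` is dense, so the sketch below shows which positional axis (temporal, height, or width) each channel ends up reading from. It works on plain lists rather than tensors. The section sizes `[1, 1, 2]` are a made-up example, and `select_mrope_sources` and `merge_mrope` are hypothetical helper names.

```python
def select_mrope_sources(mrope_section, axes=("t", "h", "w")):
    """For each channel, name the axis whose cos/sin value is used."""
    doubled = list(mrope_section) * 2  # mirrors `mrope_section * 2` in the file
    sources = []
    for i, width in enumerate(doubled):
        sources.extend([axes[i % 3]] * width)  # chunk i reads from axis i % 3
    return sources


def merge_mrope(per_axis, mrope_section):
    """Build one head_dim-long vector by picking each channel from its axis."""
    sources = select_mrope_sources(mrope_section)
    return [per_axis[axis][j] for j, axis in enumerate(sources)]


section = [1, 1, 2]  # made-up sizes; they sum to head_dim / 2 = 4
print(select_mrope_sources(section))

# Constant per-axis vectors make the selection pattern easy to see.
per_axis = {"t": [0] * 8, "h": [1] * 8, "w": [2] * 8}
print(merge_mrope(per_axis, section))
```

When all three axes carry the same position index, as the docstring notes for text tokens, the merged vector is identical to ordinary 1D RoPE.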
235
+ class PanguEmbeddedMLP(nn.Module):
236
+ def __init__(self, config, bias: bool = False):
237
+ super().__init__()
238
+ self.hidden_size = config.hidden_size
239
+ self.intermediate_size = config.intermediate_size
240
+ self.hidden_act = config.hidden_act
241
+ if self.hidden_act == "silu":
242
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
243
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
244
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
245
+ self.act_fn = ACT2FN[config.hidden_act]
246
+
247
+ def forward(self, hidden_state):
248
+ if(self.hidden_act == "silu"):
249
+ x_gate= self.gate_proj(hidden_state)
250
+ x_gate = self.act_fn(x_gate)
251
+ x_up= self.up_proj(hidden_state)
252
+ intermediate_parallel = x_gate * x_up
253
+ else:
254
+ x_up= self.up_proj(hidden_state)
255
+ intermediate_parallel = self.act_fn(x_up)
256
+ x_down = self.down_proj(intermediate_parallel)
257
+ return x_down
258
+
259
+
260
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
261
+ """
262
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
263
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
264
+ """
265
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
266
+ if n_rep == 1:
267
+ return hidden_states
268
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
269
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
270
+
271
+
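
`repeat_kv` expands the key/value heads so that each group of query heads has a matching head to attend with (grouped-query attention). The sketch below shows only the head ordering, using list entries as stand-ins for per-head tensors. The head counts come from the config defaults above (32 attention heads, 8 key/value heads), and `repeat_kv_heads` is a hypothetical name.

```python
def repeat_kv_heads(kv_heads, n_rep):
    """Repeat each key/value head n_rep times, keeping copies adjacent."""
    if n_rep == 1:
        return list(kv_heads)
    return [head for head in kv_heads for _ in range(n_rep)]


num_attention_heads = 32  # config default
num_key_value_heads = 8   # config default
n_rep = num_attention_heads // num_key_value_heads

expanded = repeat_kv_heads(list(range(num_key_value_heads)), n_rep)
# Query head q shares the key/value head with index q // n_rep.
print(n_rep, expanded[:8])
```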
272
+ def eager_attention_forward(
273
+ module: nn.Module,
274
+ query: torch.Tensor,
275
+ key: torch.Tensor,
276
+ value: torch.Tensor,
277
+ attention_mask: Optional[torch.Tensor],
278
+ scaling: float,
279
+ dropout: float = 0.0,
280
+ **kwargs,
281
+ ):
282
+ key_states = repeat_kv(key, module.num_key_value_groups)
283
+ value_states = repeat_kv(value, module.num_key_value_groups)
284
+
285
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
286
+ if attention_mask is not None:
287
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
288
+ attn_weights = attn_weights + causal_mask
289
+
290
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
291
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
292
+ attn_output = torch.matmul(attn_weights, value_states)
293
+ attn_output = attn_output.transpose(1, 2).contiguous()
294
+
295
+ return attn_output, attn_weights
296
+
297
+
298
+ class PanguEmbeddedAttention(nn.Module):
299
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
300
+
301
+ def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
302
+ super().__init__()
303
+ self.config = config
304
+ self.layer_idx = layer_idx
305
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
306
+ self.num_heads = config.num_attention_heads
307
+ self.num_key_value_heads = config.num_key_value_heads
308
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
309
+ self.scaling = self.head_dim**-0.5
310
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)  # PanguEmbeddedConfig defines no default for this field
311
+ self.is_causal = True
312
+
313
+ self.q_proj = nn.Linear(
314
+ config.hidden_size,
315
+ config.num_attention_heads * self.head_dim,
316
+ bias=config.bias,
317
+ )
318
+ self.k_proj = nn.Linear(
319
+ config.hidden_size,
320
+ config.num_key_value_heads * self.head_dim,
321
+ bias=config.bias,
322
+ )
323
+ self.v_proj = nn.Linear(
324
+ config.hidden_size,
325
+ config.num_key_value_heads * self.head_dim,
326
+ bias=config.bias,
327
+ )
328
+ self.o_proj = nn.Linear(
329
+ config.num_attention_heads * self.head_dim,
330
+ config.hidden_size,
331
+ bias=config.bias,
332
+ )
333
+
334
+ def forward(
335
+ self,
336
+ hidden_states: torch.Tensor,
337
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
338
+ attention_mask: Optional[torch.Tensor],
339
+ past_key_value: Optional[Cache] = None,
340
+ cache_position: Optional[torch.LongTensor] = None,
341
+ **kwargs: Unpack[FlashAttentionKwargs],
342
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
343
+ input_shape = hidden_states.shape[:-1]
344
+ hidden_shape = (*input_shape, -1, self.head_dim)
345
+
346
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
347
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
348
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
349
+
350
+ cos, sin = position_embeddings
351
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
352
+
353
+ if past_key_value is not None:
354
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
355
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
356
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
357
+
358
+ attention_interface: Callable = eager_attention_forward
359
+ if self.config._attn_implementation != "eager":
360
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
361
+
362
+ if not self.training and NPU_ATTN_INFR:
363
+ q_len = input_shape[1]
364
+ if attention_mask is not None:
365
+ attention_mask = ~attention_mask.bool()
366
+ elif q_len > 1:
367
+ attention_mask = (
368
+ torch.triu(torch.ones([q_len, q_len]), diagonal=1)
369
+ .bool()
370
+ .unsqueeze(0)
371
+ .unsqueeze(0)
372
+ .to(query_states.device)
373
+ )
374
+
375
+ attn_output, _ = torch_npu.npu_fused_infer_attention_score(
376
+ query_states,
377
+ key_states,
378
+ value_states,
379
+ num_heads=self.num_heads,
380
+ num_key_value_heads=self.num_key_value_heads,
381
+ input_layout="BNSD",
382
+ atten_mask=attention_mask,
383
+ scale=self.scaling,
384
+ )
385
+ attn_output = attn_output.transpose(1, 2)
386
+ attn_weights = None
387
+ else:
388
+ attn_output, attn_weights = attention_interface(
389
+ self,
390
+ query_states,
391
+ key_states,
392
+ value_states,
393
+ attention_mask,
394
+ dropout=0.0 if not self.training else self.attention_dropout,
395
+ scaling=self.scaling,
396
+ **kwargs,
397
+ )
398
+
399
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
400
+ attn_output = self.o_proj(attn_output)
401
+ return attn_output, attn_weights
402
+
403
+
404
+ class PanguEmbeddedDecoderLayer(GradientCheckpointingLayer):
405
+ def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
406
+ super().__init__()
407
+ self.hidden_size = config.hidden_size
408
+ self.self_attn = PanguEmbeddedAttention(config=config, layer_idx=layer_idx)
409
+ self.mlp = PanguEmbeddedMLP(config)
410
+ self.input_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
411
+ self.post_attention_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ attention_mask: Optional[torch.Tensor] = None,
417
+ position_ids: Optional[torch.LongTensor] = None,
418
+ past_key_value: Optional[Cache] = None,
419
+ output_attentions: Optional[bool] = False,
420
+ use_cache: Optional[bool] = False,
421
+ cache_position: Optional[torch.LongTensor] = None,
422
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
423
+ **kwargs: Unpack[FlashAttentionKwargs],
424
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
425
+ residual = hidden_states
426
+ hidden_states = self.input_layernorm(hidden_states)
427
+
428
+ # Self Attention
429
+ hidden_states, self_attn_weights = self.self_attn(
430
+ hidden_states=hidden_states,
431
+ attention_mask=attention_mask,
432
+ position_ids=position_ids,
433
+ past_key_value=past_key_value,
434
+ output_attentions=output_attentions,
435
+ use_cache=use_cache,
436
+ cache_position=cache_position,
437
+ position_embeddings=position_embeddings,
438
+ **kwargs,
439
+ )
440
+ hidden_states = residual + hidden_states
441
+
442
+ # Fully Connected
443
+ residual = hidden_states
444
+ hidden_states = self.post_attention_layernorm(hidden_states)
445
+ hidden_states = self.mlp(hidden_states)
446
+ hidden_states = residual + hidden_states
447
+
448
+ outputs = (hidden_states,)
449
+ if output_attentions:
450
+ outputs += (self_attn_weights,)
451
+
452
+ return outputs
453
+
454
+
455
+ @auto_docstring
456
+ class PanguEmbeddedPreTrainedModel(PreTrainedModel):
457
+ config_class = PanguEmbeddedConfig
458
+ base_model_prefix = "model"
459
+ supports_gradient_checkpointing = True
460
+ _no_split_modules = ["PanguEmbeddedDecoderLayer"]
461
+ _skip_keys_device_placement = ["past_key_values"]
462
+ _supports_flash_attn_3 = True
463
+ _supports_flash_attn_2 = True
464
+ _supports_sdpa = True
465
+ _supports_flex_attn = True
466
+ _supports_cache_class = True
467
+ _supports_quantized_cache = True
468
+ _supports_static_cache = True
469
+ _supports_attention_backend = True
470
+
471
+ def _init_weights(self, module):
472
+ std = self.config.initializer_range
473
+ if isinstance(module, nn.Linear):
474
+ module.weight.data.normal_(mean=0.0, std=std)
475
+ if module.bias is not None:
476
+ module.bias.data.zero_()
477
+ elif isinstance(module, nn.Embedding):
478
+ module.weight.data.normal_(mean=0.0, std=std)
479
+ if module.padding_idx is not None:
480
+ module.weight.data[module.padding_idx].zero_()
481
+ elif isinstance(module, PanguEmbeddedRMSNorm):
482
+ module.weight.data.fill_(1.0)
483
+
484
+
485
+ @auto_docstring
486
+ class PanguEmbeddedModel(PanguEmbeddedPreTrainedModel):
487
+ def __init__(self, config: PanguEmbeddedConfig):
488
+ super().__init__(config)
489
+ self.padding_idx = config.pad_token_id
490
+ self.vocab_size = config.vocab_size
491
+
492
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
493
+ self.layers = nn.ModuleList(
494
+ [PanguEmbeddedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
495
+ )
496
+ self.norm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
497
+ self.rotary_emb = PanguEmbeddedRotaryEmbedding(config=config)
498
+ self.gradient_checkpointing = False
499
+
500
+ # Initialize weights and apply final processing
501
+ self.post_init()
502
+
503
+ def get_input_embeddings(self):
504
+ return self.embed_tokens
505
+
506
+ def set_input_embeddings(self, value):
507
+ self.embed_tokens = value
508
+
509
+ @can_return_tuple
510
+ @auto_docstring
511
+ def forward(
512
+ self,
513
+ input_ids: Optional[torch.LongTensor] = None,
514
+ attention_mask: Optional[torch.Tensor] = None,
515
+ position_ids: Optional[torch.LongTensor] = None,
516
+ past_key_values: Optional[Cache] = None,
517
+ inputs_embeds: Optional[torch.FloatTensor] = None,
518
+ use_cache: Optional[bool] = None,
519
+ output_attentions: Optional[bool] = None,
520
+ output_hidden_states: Optional[bool] = None,
521
+ cache_position: Optional[torch.LongTensor] = None,
522
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
523
+ ) -> BaseModelOutputWithPast:
524
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
525
+ output_hidden_states = (
526
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
527
+ )
528
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
529
+
530
+ if (input_ids is None) ^ (inputs_embeds is not None):
531
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
532
+
533
+ if self.gradient_checkpointing and self.training and use_cache:
534
+ logger.warning_once(
535
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
536
+ )
537
+ use_cache = False
538
+
539
+ if not isinstance(past_key_values, (type(None), Cache)):
540
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
541
+
542
+ if inputs_embeds is None:
543
+ inputs_embeds = self.embed_tokens(input_ids)
544
+
545
+ if use_cache and past_key_values is None:
546
+ past_key_values = DynamicCache()
547
+
548
+ if cache_position is None:
549
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
550
+ cache_position = torch.arange(
551
+ past_seen_tokens,
552
+ past_seen_tokens + inputs_embeds.shape[1],
553
+ device=inputs_embeds.device,
554
+ )
555
+
556
+ if position_ids is None:
557
+ position_ids = cache_position.unsqueeze(0)
558
+
559
+ causal_mask = create_causal_mask(
560
+ config=self.config,
561
+ input_embeds=inputs_embeds,
562
+ attention_mask=attention_mask,
563
+ cache_position=cache_position,
564
+ past_key_values=past_key_values,
565
+ position_ids=position_ids,
566
+ )
567
+
568
+ hidden_states = inputs_embeds
569
+
570
+ # create position embeddings to be shared across the decoder layers
571
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
572
+
573
+ # decoder layers
574
+ all_hidden_states = () if output_hidden_states else None
575
+ all_self_attns = () if output_attentions else None
576
+
577
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
578
+ if output_hidden_states:
579
+ all_hidden_states += (hidden_states,)
580
+
581
+ layer_outputs = decoder_layer(
582
+ hidden_states,
583
+ attention_mask=causal_mask,
584
+ position_ids=position_ids,
585
+ past_key_value=past_key_values,
586
+ output_attentions=output_attentions,
587
+ use_cache=use_cache,
588
+ cache_position=cache_position,
589
+ position_embeddings=position_embeddings,
590
+ **flash_attn_kwargs,
591
+ )
592
+
593
+ hidden_states = layer_outputs[0]
594
+
595
+ if output_attentions:
596
+ all_self_attns += (layer_outputs[1],)
597
+
598
+ hidden_states = self.norm(hidden_states)
599
+
600
+ # add hidden states from the last decoder layer
601
+ if output_hidden_states:
602
+ all_hidden_states += (hidden_states,)
603
+
604
+ return BaseModelOutputWithPast(
605
+ last_hidden_state=hidden_states,
606
+ past_key_values=past_key_values if use_cache else None,
607
+ hidden_states=all_hidden_states,
608
+ attentions=all_self_attns,
609
+ )
610
+
611
+
612
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
613
+ ...
614
+
615
+
616
+ @auto_docstring
617
+ class PanguEmbeddedForCausalLM(PanguEmbeddedPreTrainedModel, GenerationMixin):
618
+ _tied_weights_keys = ["lm_head.weight"]
619
+ _tp_plan = {"lm_head": "colwise_rep"}
620
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
621
+
622
+ def __init__(self, config):
623
+ super().__init__(config)
624
+ self.model = PanguEmbeddedModel(config)
625
+ self.vocab_size = config.vocab_size
626
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
627
+
628
+ # Initialize weights and apply final processing
629
+ self.post_init()
630
+
631
+ def get_input_embeddings(self):
632
+ return self.model.embed_tokens
633
+
634
+ def set_input_embeddings(self, value):
635
+ self.model.embed_tokens = value
636
+
637
+ def get_output_embeddings(self):
638
+ return self.lm_head
639
+
640
+ def set_output_embeddings(self, new_embeddings):
641
+ self.lm_head = new_embeddings
642
+
643
+ def set_decoder(self, decoder):
644
+ self.model = decoder
645
+
646
+ def get_decoder(self):
647
+ return self.model
648
+
649
+ @can_return_tuple
650
+ @auto_docstring
651
+ def forward(
652
+ self,
653
+ input_ids: Optional[torch.LongTensor] = None,
654
+ attention_mask: Optional[torch.Tensor] = None,
655
+ position_ids: Optional[torch.LongTensor] = None,
656
+ past_key_values: Optional[Cache] = None,
657
+ inputs_embeds: Optional[torch.FloatTensor] = None,
658
+ labels: Optional[torch.LongTensor] = None,
659
+ use_cache: Optional[bool] = None,
660
+ output_attentions: Optional[bool] = None,
661
+ output_hidden_states: Optional[bool] = None,
662
+ cache_position: Optional[torch.LongTensor] = None,
663
+ logits_to_keep: Union[int, torch.Tensor] = 0,
664
+ **kwargs: Unpack[KwargsForCausalLM],
665
+ ) -> CausalLMOutputWithPast:
666
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
667
+ output_hidden_states = (
668
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
669
+ )
670
+
671
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
672
+ outputs: BaseModelOutputWithPast = self.model(
673
+ input_ids=input_ids,
674
+ attention_mask=attention_mask,
675
+ position_ids=position_ids,
676
+ past_key_values=past_key_values,
677
+ inputs_embeds=inputs_embeds,
678
+ use_cache=use_cache,
679
+ output_attentions=output_attentions,
680
+ output_hidden_states=output_hidden_states,
681
+ cache_position=cache_position,
682
+ **kwargs,
683
+ )
684
+
685
+ hidden_states = outputs.last_hidden_state
686
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
687
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
688
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
689
+
690
+ loss = None
691
+ if labels is not None:
692
+ loss = self.loss_function(
693
+ logits=logits,
694
+ labels=labels,
695
+ vocab_size=self.config.vocab_size,
696
+ **kwargs,
697
+ )
698
+
699
+ return CausalLMOutputWithPast(
700
+ loss=loss,
701
+ logits=logits,
702
+ past_key_values=outputs.past_key_values,
703
+ hidden_states=outputs.hidden_states,
704
+ attentions=outputs.attentions,
705
+ )
706
+
707
+
708
+ __all__ = [
709
+ "PanguEmbeddedForCausalLM",
710
+ "PanguEmbeddedModel",
711
+ "PanguEmbeddedPreTrainedModel",
712
+ ]
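The `logits_to_keep` argument of `PanguEmbeddedForCausalLM.forward` above controls which sequence positions are projected through `lm_head`. The following is a minimal sketch of that slicing rule, not part of the commit: it uses plain Python lists in place of tensors, and `select_positions` is a hypothetical helper name introduced only for illustration.

```python
from typing import Sequence, Union


def select_positions(hidden: Sequence, logits_to_keep: Union[int, Sequence[int]] = 0):
    """Mimic `hidden_states[:, slice_indices, :]` along the sequence axis."""
    if isinstance(logits_to_keep, int):
        # -0 == 0, so the default gives slice(0, None): every position is kept.
        slice_indices = slice(-logits_to_keep, None)
        return list(hidden[slice_indices])
    # A sequence of indices selects exactly those positions (tensor-style indexing).
    return [hidden[i] for i in logits_to_keep]


tokens = ["t0", "t1", "t2", "t3"]
all_positions = select_positions(tokens)      # default 0: whole sequence
last_only = select_positions(tokens, 1)       # only the final position
picked = select_positions(tokens, [0, 2])     # explicit indices
```

Keeping only the last position is the usual choice during generation, since only the next-token logits are needed and the full-sequence projection can be skipped.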
modeling_openpangu_vl.py ADDED
@@ -0,0 +1,1766 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from dataclasses import dataclass
23
+ from typing import Any, Callable, Optional, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import torch_npu
29
+ from einops import rearrange
30
+
31
+ from transformers.cache_utils import Cache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
+ from transformers.modeling_layers import GradientCheckpointingLayer
35
+ from transformers.modeling_outputs import ModelOutput
36
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
+ from transformers.processing_utils import Unpack
39
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
40
+
41
+ from .configuration_openpangu_vl import OpenPanguVLConfig as OpenPanguConfig
42
+ from .configuration_openpangu_vl import OpenPanguVLTextConfig, OpenPanguVLVisionConfig
43
+ from .modeling_openpangu_embedded import PanguEmbeddedConfig, PanguEmbeddedMLP, PanguEmbeddedModel, PanguEmbeddedRMSNorm
44
+ from .imageprocessor_openpangu_vl import rescale_and_normalize
45
+ if "910" in torch.npu.get_device_name():
46
+ NPU_ATTN_INFR = True
47
+ print("[INFO] torch_npu detected. Using NPU fused infer attention.")
48
+ else:
49
+ NPU_ATTN_INFR = False
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ class OpenPanguVLMLP(PanguEmbeddedMLP):
56
+ pass
57
+
58
+
59
+ class OpenPanguVisionPatchEmbed(nn.Module):
60
+ def __init__(
61
+ self,
62
+ patch_size: int = 14,
63
+ temporal_patch_size: int = 2,
64
+ in_channels: int = 3,
65
+ embed_dim: int = 1152,
66
+ ) -> None:
67
+ super().__init__()
68
+ self.patch_size = patch_size
69
+ self.temporal_patch_size = temporal_patch_size
70
+ self.in_channels = in_channels
71
+ self.embed_dim = embed_dim
72
+
73
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
74
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
75
+ self.input_size = self.patch_size * self.patch_size * in_channels * self.temporal_patch_size
76
+
77
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
78
+ if hidden_states.shape[-1] != self.input_size:
79
+ hidden_states = torch.cat([hidden_states.reshape(-1, self.patch_size * self.patch_size), \
80
+ hidden_states.reshape(-1, self.patch_size * self.patch_size)], dim=-1).reshape(-1, self.input_size)
81
+ target_dtype = self.proj.weight.dtype
82
+ hidden_states = hidden_states.view(
83
+ -1,
84
+ self.in_channels,
85
+ self.temporal_patch_size,
86
+ self.patch_size,
87
+ self.patch_size,
88
+ )
89
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
90
+ return hidden_states
91
+
92
+
93
+ class OpenPanguVLPatchEmbed(OpenPanguVisionPatchEmbed):
94
+ pass
95
+
96
+
97
+ class OpenPanguVisionRotaryEmbedding(nn.Module):
98
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
99
+ super().__init__()
100
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
101
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
102
+
103
+ def forward(self, seqlen: int) -> torch.Tensor:
104
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
105
+ freqs = torch.outer(seq, self.inv_freq)
106
+ return freqs
107
+
108
+
109
+ class OpenPanguRMSNorm(PanguEmbeddedRMSNorm):
110
+ pass
111
+
112
+
113
+ class OpenPanguVLPatchMerger(nn.Module):
114
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
115
+ super().__init__()
116
+ self.hidden_size = context_dim * (spatial_merge_size**2)
117
+ self.ln_q = OpenPanguRMSNorm(context_dim, eps=1e-6)
118
+ self.mlp = nn.Sequential(
119
+ nn.Linear(self.hidden_size, self.hidden_size),
120
+ nn.GELU(),
121
+ nn.Linear(self.hidden_size, dim),
122
+ )
123
+
124
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
125
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
126
+ return x
127
+
128
+
129
+ def rotate_half(x):
130
+ """Rotates half the hidden dims of the input."""
131
+ x1 = x[..., : x.shape[-1] // 2]
132
+ x2 = x[..., x.shape[-1] // 2 :]
133
+ return torch.cat((-x2, x1), dim=-1)
134
+
135
+
136
+ def apply_rotary_pos_emb_vision(
137
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
138
+ ) -> tuple[torch.Tensor, torch.Tensor]:
139
+ orig_q_dtype = q.dtype
140
+ orig_k_dtype = k.dtype
141
+ q, k = q.float(), k.float()
142
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
143
+ q_embed = (q * cos) + (rotate_half(q) * sin)
144
+ k_embed = (k * cos) + (rotate_half(k) * sin)
145
+ q_embed = q_embed.to(orig_q_dtype)
146
+ k_embed = k_embed.to(orig_k_dtype)
147
+ return q_embed, k_embed
148
+
149
+
150
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
151
+ """
152
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
153
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
154
+ """
155
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
156
+ if n_rep == 1:
157
+ return hidden_states
158
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
159
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
160
+
161
+
162
+ def eager_attention_forward(
163
+ module: nn.Module,
164
+ query: torch.Tensor,
165
+ key: torch.Tensor,
166
+ value: torch.Tensor,
167
+ attention_mask: Optional[torch.Tensor],
168
+ scaling: float,
169
+ dropout: float = 0.0,
170
+ **kwargs,
171
+ ):
172
+ key_states = repeat_kv(key, module.num_key_value_groups)
173
+ value_states = repeat_kv(value, module.num_key_value_groups)
174
+
175
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
176
+ if attention_mask is not None:
177
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
178
+ attn_weights = attn_weights + causal_mask
179
+
180
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
181
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
182
+ attn_output = torch.matmul(attn_weights, value_states)
183
+ attn_output = attn_output.transpose(1, 2).contiguous()
184
+
185
+ return attn_output, attn_weights
186
+
187
+
188
+ class OpenPanguVLVisionAttention(nn.Module):
189
+ def __init__(self, config: OpenPanguVLVisionConfig) -> None:
190
+ super().__init__()
191
+ self.dim = config.hidden_size
192
+ self.num_heads = config.num_heads
193
+ self.head_dim = self.dim // self.num_heads
194
+ self.num_key_value_groups = 1 # needed for eager attention
195
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
196
+ self.proj = nn.Linear(self.dim, self.dim)
197
+ self.scaling = self.head_dim**-0.5
198
+ self.config = config
199
+ self.attention_dropout = 0.0
200
+ self.is_causal = False
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states: torch.Tensor,
205
+ cu_seqlens: torch.Tensor,
206
+ rotary_pos_emb: Optional[torch.Tensor] = None,
207
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
208
+ attention_mask: Optional[torch.Tensor] = None,
209
+ **kwargs,
210
+ ) -> torch.Tensor:
211
+ seq_length = hidden_states.shape[0]
212
+ query_states, key_states, value_states = (
213
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
214
+ )
215
+ if position_embeddings is None:
216
+ logger.warning_once(
217
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
218
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
219
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
220
+ "removed and `position_embeddings` will be mandatory."
221
+ )
222
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
223
+ cos = emb.cos()
224
+ sin = emb.sin()
225
+ else:
226
+ cos, sin = position_embeddings
227
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
228
+
229
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
230
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
231
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
232
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
233
+
234
+ attention_interface: Callable = eager_attention_forward
235
+ if self.config._attn_implementation != "eager":
236
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
237
+
238
+ if not self.training and NPU_ATTN_INFR:
239
+ if isinstance(cu_seqlens, torch.Tensor):
240
+ cu_seqlens = cu_seqlens.tolist()
241
+
242
+ q, k, v = [rearrange(x, "b n s d -> (b s) n d") for x in [query_states, key_states, value_states]]
243
+ attn_output = torch_npu.npu_fusion_attention(
244
+ q,
245
+ k,
246
+ v,
247
+ self.num_heads,
248
+ "TND",
249
+ pse=None,
250
+ padding_mask=None,
251
+ atten_mask=None,
252
+ scale=self.scaling,
253
+ pre_tockens=1048576,
254
+ next_tockens=0,
255
+ keep_prob=1.0,
256
+ inner_precise=0,
257
+ sparse_mode=0,
258
+ actual_seq_qlen=cu_seqlens,
259
+ actual_seq_kvlen=cu_seqlens,
260
+ )[0]
261
+ else:
262
+ attn_output, _ = attention_interface(
263
+ self,
264
+ query_states,
265
+ key_states,
266
+ value_states,
267
+ attention_mask=attention_mask,
268
+ dropout=0.0 if not self.training else self.attention_dropout,
269
+ scaling=self.scaling,
270
+ cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
271
+ cu_seq_lens_k=cu_seqlens,
272
+ max_length_q=max_seqlen,
273
+ max_length_k=max_seqlen,
274
+ is_causal=False,
275
+ **kwargs,
276
+ )
277
+
278
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
279
+ attn_output = self.proj(attn_output)
280
+ return attn_output
281
+
282
+
283
+ class OpenPanguVLVisionBlock(GradientCheckpointingLayer):
284
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
285
+ super().__init__()
286
+ self.norm1 = OpenPanguRMSNorm(config.hidden_size, eps=1e-6)
287
+ self.norm2 = OpenPanguRMSNorm(config.hidden_size, eps=1e-6)
288
+ self.attn = OpenPanguVLVisionAttention(config=config)
289
+ self.mlp = OpenPanguVLMLP(config, bias=True)
290
+
291
+ def forward(
292
+ self,
293
+ hidden_states: torch.Tensor,
294
+ cu_seqlens: torch.Tensor,
295
+ rotary_pos_emb: Optional[torch.Tensor] = None,
296
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ **kwargs,
299
+ ) -> torch.Tensor:
300
+ hidden_states = hidden_states + self.attn(
301
+ self.norm1(hidden_states),
302
+ cu_seqlens=cu_seqlens,
303
+ rotary_pos_emb=rotary_pos_emb,
304
+ position_embeddings=position_embeddings,
305
+ attention_mask=attention_mask,
306
+ **kwargs,
307
+ )
308
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
309
+ return hidden_states
310
+
311
+
312
+ @auto_docstring
313
+ class OpenPanguPreTrainedModel(PreTrainedModel):
314
+ config_class = OpenPanguConfig
315
+ base_model_prefix = "model"
316
+ supports_gradient_checkpointing = True
317
+ _no_split_modules = ["OpenPanguVLDecoderLayer", "OpenPanguVLVisionBlock"]
318
+ _skip_keys_device_placement = "past_key_values"
319
+ _supports_flash_attn_2 = True
320
+ _supports_sdpa = True
321
+ _supports_cache_class = True
322
+ _supports_static_cache = True
323
+ _supports_attention_backend = True
324
+
325
+ def _init_weights(self, module):
326
+ std = self.config.get_text_config().initializer_range
327
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
328
+ module.weight.data.normal_(mean=0.0, std=std)
329
+ if module.bias is not None:
330
+ module.bias.data.zero_()
331
+ elif isinstance(module, nn.Embedding):
332
+ module.weight.data.normal_(mean=0.0, std=std)
333
+ if module.padding_idx is not None:
334
+ module.weight.data[module.padding_idx].zero_()
335
+ elif isinstance(module, OpenPanguRMSNorm):
336
+ module.weight.data.fill_(1.0)
337
+
338
+
339
+ class OpenPanguVisionTransformerPretrainedModel(OpenPanguPreTrainedModel):
340
+ config_class = OpenPanguVLVisionConfig
341
+ _no_split_modules = ["OpenPanguVLVisionBlock"]
342
+
343
+ def __init__(self, config, *inputs, **kwargs) -> None:
344
+ super().__init__(config, *inputs, **kwargs)
345
+ self.spatial_merge_size = config.spatial_merge_size
346
+ self.patch_size = config.patch_size
347
+ self.fullatt_block_indexes = config.fullatt_block_indexes
348
+ self.window_size = config.window_size
349
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
350
+ self.patch_embed = OpenPanguVLPatchEmbed(
351
+ patch_size=config.patch_size,
352
+ temporal_patch_size=config.temporal_patch_size,
353
+ in_channels=config.in_channels,
354
+ embed_dim=config.hidden_size,
355
+ )
356
+ head_dim = config.hidden_size // config.num_heads
357
+ self.rotary_pos_emb = OpenPanguVisionRotaryEmbedding(head_dim // 2)
358
+ self.blocks = nn.ModuleList([OpenPanguVLVisionBlock(config) for _ in range(config.depth)])
359
+ self.select_layer = getattr(config, "mm_unit_vision_select_layer", [-1, -3])
360
+ self.select_index = [config.depth + i for i in self.select_layer]
361
+ self.select_index = self.select_index[::-1]
362
+ self.select_layer = [-1 * (i + 1) for i in range(len(self.select_index))]
363
+ self.merger = nn.ModuleList(
364
+ [
365
+ OpenPanguVLPatchMerger(
366
+ dim=config.out_hidden_size,
367
+ context_dim=config.hidden_size,
368
+ spatial_merge_size=config.spatial_merge_size,
369
+ )
370
+ for i in range(len(self.select_layer))
371
+ ]
372
+ )
373
+ self.gradient_checkpointing = False
374
+ self.take_indices = self.select_index
375
+ self.final_layernorm = OpenPanguRMSNorm(config.hidden_size, eps=1e-6)
376
+
377
+ def rot_pos_emb(self, grid_thw):
378
+ pos_ids = []
379
+ for t, h, w in grid_thw:
380
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
381
+ hpos_ids = hpos_ids.reshape(
382
+ h // self.spatial_merge_size,
383
+ self.spatial_merge_size,
384
+ w // self.spatial_merge_size,
385
+ self.spatial_merge_size,
386
+ )
387
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
388
+ hpos_ids = hpos_ids.flatten()
389
+
390
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
391
+ wpos_ids = wpos_ids.reshape(
392
+ h // self.spatial_merge_size,
393
+ self.spatial_merge_size,
394
+ w // self.spatial_merge_size,
395
+ self.spatial_merge_size,
396
+ )
397
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
398
+ wpos_ids = wpos_ids.flatten()
399
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
400
+ pos_ids = torch.cat(pos_ids, dim=0)
401
+ max_grid_size = grid_thw[:, 1:].max()
402
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
403
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
404
+ return rotary_pos_emb
405
+
406
+ def get_window_index(self, grid_thw):
407
+ window_index: list = []
408
+ cu_window_seqlens: list = [0]
409
+ window_index_id = 0
410
+ vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
411
+
412
+ for grid_t, grid_h, grid_w in grid_thw:
413
+ llm_grid_h, llm_grid_w = (
414
+ grid_h // self.spatial_merge_size,
415
+ grid_w // self.spatial_merge_size,
416
+ )
417
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
418
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
419
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
420
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
421
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
422
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
423
+ index_padded = index_padded.reshape(
424
+ grid_t,
425
+ num_windows_h,
426
+ vit_merger_window_size,
427
+ num_windows_w,
428
+ vit_merger_window_size,
429
+ )
430
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
431
+ grid_t,
432
+ num_windows_h * num_windows_w,
433
+ vit_merger_window_size,
434
+ vit_merger_window_size,
435
+ )
436
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
437
+ index_padded = index_padded.reshape(-1)
438
+ index_new = index_padded[index_padded != -100]
439
+ window_index.append(index_new + window_index_id)
440
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
441
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
442
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
443
+ window_index = torch.cat(window_index, dim=0)
444
+
445
+ return window_index, cu_window_seqlens
446
+
447
+ def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
448
+ # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
449
+ # NOTE: the created attention mask only approximates the ragged FA2 attention by
450
+ # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
451
+ # blocks, so it will not be a 100% match for FA2's `varlen` path.
452
+ if self.config._attn_implementation == "flash_attention_2":
453
+ return None
454
+
455
+ seq_length = inputs_tensor.shape[0]
456
+ attention_mask = torch.full(
457
+ [1, 1, seq_length, seq_length],
458
+ torch.finfo(inputs_tensor.dtype).min,
459
+ device=inputs_tensor.device,
460
+ dtype=inputs_tensor.dtype,
461
+ )
462
+ for i in range(1, len(cu_seqlens)):
463
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
464
+ return attention_mask
465
+
466
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
467
+ """
468
+ Args:
469
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
470
+ Flattened image/video patches that are fed to the patch embedding.
471
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
472
+ The temporal, height and width sizes of the patch grid of each image or video.
473
+
474
+ Returns:
475
+ `torch.Tensor`: hidden_states.
476
+ """
477
+ hidden_states = self.patch_embed(hidden_states)
478
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
479
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
480
+ cu_window_seqlens = torch.tensor(
481
+ cu_window_seqlens,
482
+ device=hidden_states.device,
483
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
484
+ )
485
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
486
+
487
+ seq_len, _ = hidden_states.size()
488
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
489
+ hidden_states = hidden_states[window_index, :, :]
490
+ hidden_states = hidden_states.reshape(seq_len, -1)
491
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
492
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
493
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
494
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
495
+ position_embeddings = (emb.cos(), emb.sin())
496
+
497
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
498
+ dim=0,
499
+ # Select dtype based on the following factors:
500
+ # - FA2 requires that cu_seqlens_q must have dtype int32
501
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
502
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
503
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
504
+ )
505
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
506
+ intermediates = []
507
+ for layer_num, blk in enumerate(self.blocks):
508
+ if layer_num in self.fullatt_block_indexes:
509
+ cu_seqlens_now = cu_seqlens
510
+ else:
511
+ cu_seqlens_now = cu_window_seqlens
512
+
513
+ attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
514
+ hidden_states = blk(
515
+ hidden_states,
516
+ cu_seqlens=cu_seqlens_now,
517
+ position_embeddings=position_embeddings,
518
+ attention_mask=attention_mask,
519
+ **kwargs,
520
+ )
521
+ if layer_num in self.take_indices:
522
+ ln_hs = self.final_layernorm(hidden_states)
523
+ intermediates.append(ln_hs)
524
+
525
+ image_embeddings_list = []
526
+ for idx, sl in enumerate(self.select_layer):
527
+ image_embeddings_list.append(self.merger[idx](intermediates[sl]))
528
+ hidden_states = sum(image_embeddings_list)
529
+
530
+ reverse_indices = torch.argsort(window_index)
531
+ hidden_states = hidden_states[reverse_indices, :]
532
+
533
+ return hidden_states
534
+
535
+
536
+ @dataclass
537
+ @auto_docstring(
538
+ custom_intro="""
539
+ Base class for OpenPanguVL outputs, with hidden states and attentions.
540
+ """
541
+ )
542
+ class OpenPanguVLModelOutputWithPast(ModelOutput):
543
+ r"""
544
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
545
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
546
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`
547
+
548
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
549
+ `past_key_values` input) to speed up sequential decoding.
550
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
551
+ The rope index difference between sequence length and multimodal rope.
552
+ """
553
+
554
+ last_hidden_state: torch.FloatTensor = None
555
+ past_key_values: Optional[list[torch.FloatTensor]] = None
556
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
557
+ attentions: Optional[tuple[torch.FloatTensor]] = None
558
+ rope_deltas: Optional[torch.LongTensor] = None
559
+
560
+
561
+ class OpenPanguVLRotaryEmbedding(nn.Module):
562
+ def __init__(self, config: OpenPanguVLTextConfig, device=None):
563
+ super().__init__()
564
+ # BC: "rope_type" was originally "type"
565
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
566
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
567
+ self.mrope_interleaved = config.rope_scaling.get("mrope_interleaved", False)
568
+ else:
569
+ self.rope_type = "default"
570
+ self.max_seq_len_cached = config.max_position_embeddings
571
+ self.original_max_seq_len = config.max_position_embeddings
572
+
573
+ self.config = config
574
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
575
+
576
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
577
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
578
+ self.original_inv_freq = self.inv_freq
579
+
580
+ mrope_section = config.rope_scaling.get("mrope_section", None)
581
+ self.mrope_section = mrope_section
582
+ if self.mrope_interleaved:
583
+ if not self.mrope_section:
584
+ raise AssertionError("when you use interleave mrope, mrope_section cannot be None.")
585
+
586
+ # Generate interleaved indices
587
+ if len(mrope_section) == 2:
588
+ h_num, w_num = mrope_section[0], mrope_section[1]
589
+ mrope_dim = self.get_mrope_interleaved_id_list(h_num, w_num, 0)
590
+ elif len(mrope_section) == 3:
591
+ t_num, h_num, w_num = mrope_section[0], mrope_section[1], mrope_section[2]
592
+ mrope_dim = self.get_mrope_interleaved_id_list(t_num, h_num, w_num, force_last=True)
593
+ else:
594
+ raise AssertionError("mrope_section must have length 2 or 3.")
595
+ mrope_dim = mrope_dim * 2
596
+ self.mrope_dim = mrope_dim
597
+
598
+ @torch.no_grad()
599
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
600
+ def forward(self, x, position_ids):
601
+ # In contrast to other models, OpenPanguVL has different position ids for the grids
602
+ # So we expand the inv_freq to shape (3, ...)
603
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
604
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
605
+
606
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
607
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
608
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
609
+ emb = torch.cat((freqs, freqs), dim=-1)
610
+ # mrope interleaved
611
+ if self.mrope_interleaved:
612
+ mrope_section_3d = [1] * len(self.mrope_dim)
613
+ mrope_dim = self.mrope_dim
614
+ emb = torch.cat([m[mrope_dim[i]] for i, m in enumerate(emb.split(mrope_section_3d, dim=-1))], dim=-1)
615
+
616
+ cos = emb.cos() * self.attention_scaling
617
+ sin = emb.sin() * self.attention_scaling
618
+ # normal mrope
619
+ if not self.mrope_interleaved and self.mrope_section:
620
+ mrope_section = self.mrope_section * 2
621
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
622
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
623
+
624
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
625
+
626
+ @staticmethod
627
+ def get_mrope_interleaved_id_list(a: int, b: int, c: int, force_last: bool = False) -> list[int]:
628
+ """
629
+ Generate an interleaved list of indices for multi-modal rotary embedding.
630
+
631
+ Args:
632
+ a: Number of indices for first modality
633
+ b: Number of indices for second modality
634
+ c: Number of indices for third modality
635
+ force_last: Whether to force the last element to be from the first modality
636
+
637
+ Returns:
638
+ List of interleaved indices
639
+ """
640
+ if force_last:
641
+ a -= 1
642
+
643
+ counts = {0: a, 1: b, 2: c}
644
+ placed = dict.fromkeys(counts, 0)
645
+ rem = counts.copy()
646
+ seq: list[int] = []
647
+ last = None
648
+
649
+ total = a + b + c
650
+ for _ in range(total):
651
+ # Candidates: remaining > 0 and ≠ last
652
+ cands = [k for k in rem if rem[k] > 0 and k != last]
653
+ if not cands:
654
+ # If only last remains, relax the condition
655
+ cands = [k for k in rem if rem[k] > 0]
656
+
657
+ # Select the rarest candidate
658
+ try:
659
+ best = min(cands, key=lambda k: (placed[k] / counts[k], k))
660
+ except KeyError:
661
+ best = 0
662
+
663
+ seq.append(best)
664
+ placed[best] += 1
665
+ rem[best] -= 1
666
+ last = best
667
+
668
+ if force_last:
669
+ seq.append(0)
670
+
671
+ return seq
672
+
673
+
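To illustrate the ordering that `get_mrope_interleaved_id_list` produces, the following standalone sketch restates the same greedy rule in plain Python. The name `interleave_ids` is hypothetical, and the section sizes are made up for illustration; they are not the model's configured `mrope_section`.

```python
def interleave_ids(a: int, b: int, c: int, force_last: bool = False) -> list[int]:
    """Greedy interleaving of indices 0, 1, 2 (a, b and c copies respectively)."""
    if force_last:
        a -= 1  # reserve one copy of index 0 for the final slot
    counts = {0: a, 1: b, 2: c}
    placed = dict.fromkeys(counts, 0)
    rem = counts.copy()
    seq: list[int] = []
    last = None
    for _ in range(a + b + c):
        # Prefer any index with copies left that differs from the previous one.
        cands = [k for k in rem if rem[k] > 0 and k != last]
        if not cands:
            # Only the previous index is left, so allow a repeat.
            cands = [k for k in rem if rem[k] > 0]
        # Pick the index that is proportionally least placed; ties go to the smaller index.
        best = min(cands, key=lambda k: (placed[k] / counts[k], k))
        seq.append(best)
        placed[best] += 1
        rem[best] -= 1
        last = best
    if force_last:
        seq.append(0)
    return seq


print(interleave_ids(2, 2, 2))                   # [0, 1, 2, 0, 1, 2]
print(interleave_ids(3, 2, 2, force_last=True))  # [0, 1, 2, 0, 1, 2, 0]
print(interleave_ids(2, 2, 0))                   # [0, 1, 0, 1]
```

In the rotary embedding above, the resulting list is then repeated (`mrope_dim * 2`) so that it covers both halves of the concatenated `emb` tensor.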
674
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
675
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
676
+
677
+ Explanation:
678
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
679
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
680
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
681
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
682
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
683
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
684
+ difference with modern LLMs.
685
+
686
+ Args:
687
+ q (`torch.Tensor`): The query tensor.
688
+ k (`torch.Tensor`): The key tensor.
689
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
690
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
691
+ position_ids (`torch.Tensor`):
692
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
693
+ used to pass offsetted position ids when working with a KV-cache.
694
+ mrope_section(`List(int)`):
695
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
696
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
697
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
698
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
699
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
700
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
701
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
702
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
703
+ Returns:
704
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
705
+ """
706
+ mrope_section = mrope_section * 2
707
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
708
+ unsqueeze_dim
709
+ )
710
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
711
+ unsqueeze_dim
712
+ )
713
+
714
+ q_embed = (q * cos) + (rotate_half(q) * sin)
715
+ k_embed = (k * cos) + (rotate_half(k) * sin)
716
+ return q_embed, k_embed
717
+
718
+
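The channel selection in `apply_multimodal_rotary_pos_emb` (`m[i % 3]` over the chunks of `cos.split(mrope_section, dim=-1)`) can be pictured without tensors. The sketch below, assuming a three-entry `mrope_section`, lists which axis (0 = temporal, 1 = height, 2 = width) supplies the cos/sin value for each channel. The helper name `axis_per_channel` and the section `[2, 1, 1]` are hypothetical, not the model's actual configuration.

```python
def axis_per_channel(mrope_section: list[int]) -> list[int]:
    """Return, for each rotary channel, which axis (0=t, 1=h, 2=w) supplies cos/sin."""
    doubled = mrope_section * 2  # list repetition, as in `mrope_section * 2`
    axes: list[int] = []
    for i, width in enumerate(doubled):
        # Chunk i of the channel dimension is taken from axis i % 3.
        axes.extend([i % 3] * width)
    return axes


section = [2, 1, 1]  # hypothetical; its sum equals half of an assumed head_dim of 8
print(axis_per_channel(section))  # [0, 0, 1, 2, 0, 0, 1, 2]
```

For text tokens the three axes carry identical position ids, so this selection has no effect there and the result matches ordinary 1D rotary embedding.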
719
+ class OpenPanguVLAttention(nn.Module):
720
+ """
721
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
722
+ and "Generating Long Sequences with Sparse Transformers".
723
+ """
724
+
725
+ def __init__(self, config: OpenPanguVLTextConfig, layer_idx: Optional[int] = None):
726
+ super().__init__()
727
+ self.config = config
728
+ self.layer_idx = layer_idx
729
+ if layer_idx is None:
730
+ logger.warning_once(
731
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
731
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
733
+ "when creating this class."
734
+ )
735
+
736
+ self.hidden_size = config.hidden_size
737
+ self.num_heads = config.num_attention_heads
738
+ self.head_dim = self.hidden_size // self.num_heads
739
+ self.num_key_value_heads = config.num_key_value_heads
740
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
741
+ self.is_causal = True
742
+ self.attention_dropout = config.attention_dropout
743
+ self.rope_scaling = config.rope_scaling
744
+ self.scaling = self.head_dim**-0.5
745
+
746
+ if (self.head_dim * self.num_heads) != self.hidden_size:
747
+ raise ValueError(
748
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
749
+ f" and `num_heads`: {self.num_heads})."
750
+ )
751
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
752
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
753
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
754
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
755
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
756
+ self.rotary_emb = OpenPanguVLRotaryEmbedding(config=config)
757
+
758
+ def forward(
759
+ self,
760
+ hidden_states: torch.Tensor,
761
+ attention_mask: Optional[torch.Tensor] = None,
762
+ position_ids: Optional[torch.LongTensor] = None,
763
+ past_key_value: Optional[Cache] = None,
764
+ output_attentions: bool = False,
765
+ use_cache: bool = False,
766
+ cache_position: Optional[torch.LongTensor] = None,
767
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
768
+ **kwargs: Unpack[FlashAttentionKwargs],
769
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
770
+ bsz, q_len, _ = hidden_states.size()
771
+
772
+ query_states = self.q_proj(hidden_states)
773
+ key_states = self.k_proj(hidden_states)
774
+ value_states = self.v_proj(hidden_states)
775
+
776
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
777
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
778
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
779
+
780
+ cos, sin = position_embeddings
781
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
782
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
783
+ )
784
+
785
+ if past_key_value is not None:
786
+ cache_kwargs = {
787
+ "sin": sin,
788
+ "cos": cos,
789
+ "cache_position": cache_position,
790
+ } # Specific to RoPE models
791
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
792
+
793
+ attention_interface: Callable = eager_attention_forward
794
+ if self.config._attn_implementation != "eager":
795
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
796
+
797
+ attn_output, attn_weights = attention_interface(
798
+ self,
799
+ query_states,
800
+ key_states,
801
+ value_states,
802
+ attention_mask,
803
+ dropout=0.0 if not self.training else self.attention_dropout,
804
+ scaling=self.scaling,
805
+ sliding_window=self.sliding_window,
806
+ **kwargs,
807
+ )
808
+
809
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
810
+ attn_output = self.o_proj(attn_output)
811
+ return attn_output, attn_weights, past_key_value
812
+
813
+
814
+ class OpenPanguVLDecoderLayer(GradientCheckpointingLayer):
815
+ def __init__(self, config: OpenPanguVLTextConfig, layer_idx: int):
816
+ super().__init__()
817
+ self.hidden_size = config.hidden_size
818
+
819
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
820
+ logger.warning_once(
821
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
822
+ "unexpected results may be encountered."
823
+ )
824
+ self.self_attn = OpenPanguVLAttention(config, layer_idx)
825
+ self.mlp = OpenPanguVLMLP(config)
826
+ self.input_layernorm = OpenPanguRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
827
+ self.post_attention_layernorm = OpenPanguRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
828
+ self.attention_type = config.layer_types[layer_idx]
829
+
830
+ def forward(
831
+ self,
832
+ hidden_states: torch.Tensor,
833
+ attention_mask: Optional[torch.Tensor] = None,
834
+ position_ids: Optional[torch.LongTensor] = None,
835
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
836
+ output_attentions: Optional[bool] = False,
837
+ use_cache: Optional[bool] = False,
838
+ cache_position: Optional[torch.LongTensor] = None,
839
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
840
+ **kwargs: Unpack[FlashAttentionKwargs],
841
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
842
+ """
843
+ Args:
844
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
845
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
846
+ `(batch, sequence_length)` where padding elements are indicated by 0.
847
+ output_attentions (`bool`, *optional*):
848
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
849
+ returned tensors for more detail.
850
+ use_cache (`bool`, *optional*):
851
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
852
+ (see `past_key_values`).
853
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
854
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
855
+ Indices depicting the position of the input sequence tokens in the sequence.
856
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
857
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
858
+ with `head_dim` being the embedding dimension of each attention head.
859
+ kwargs (`dict`, *optional*):
860
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
861
+ into the model
862
+ """
863
+
864
+ residual = hidden_states
865
+
866
+ hidden_states = self.input_layernorm(hidden_states)
867
+
868
+ # Self Attention
869
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
870
+ hidden_states=hidden_states,
871
+ attention_mask=attention_mask,
872
+ position_ids=position_ids,
873
+ past_key_value=past_key_value,
874
+ output_attentions=output_attentions,
875
+ use_cache=use_cache,
876
+ cache_position=cache_position,
877
+ position_embeddings=position_embeddings,
878
+ **kwargs,
879
+ )
880
+ hidden_states = residual + hidden_states
881
+
882
+ # Fully Connected
883
+ residual = hidden_states
884
+ hidden_states = self.post_attention_layernorm(hidden_states)
885
+ hidden_states = self.mlp(hidden_states)
886
+ hidden_states = residual + hidden_states
887
+
888
+ outputs = (hidden_states,)
889
+
890
+ if output_attentions:
891
+ outputs += (self_attn_weights,)
892
+
893
+ if use_cache:
894
+ outputs += (present_key_value,)
895
+
896
+ return outputs
897
+
898
+
899
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
900
+
901
+
902
+ class ProjectionSingle(nn.Module):
903
+ def __init__(self, i_hidden_size: int, t_hidden_size: int):
904
+ super().__init__()
905
+ self.act = F.silu
906
+ self.fc1 = nn.Linear(i_hidden_size, t_hidden_size, bias=True) # bias defaults to True
907
+
908
+ def forward(self, hidden_states):
909
+ x = self.act(hidden_states)
910
+ return self.fc1(x)
911
+
912
+
913
+ @auto_docstring
914
+ class OpenPanguVLTextModel(PanguEmbeddedModel):
915
+ def __init__(self, config: PanguEmbeddedConfig):
916
+ super().__init__(config)
917
+ self.rotary_emb = OpenPanguVLRotaryEmbedding(config=config)
918
+
919
+
920
+ @auto_docstring
921
+ class OpenPanguVLModel(OpenPanguPreTrainedModel):
922
+ base_model_prefix = ""
923
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
924
+ config_class = OpenPanguConfig
925
+ _no_split_modules = ["OpenPanguVLDecoderLayer", "OpenPanguVLVisionBlock"]
926
+
927
+ def __init__(self, config):
928
+ super().__init__(config)
929
+ self.visual = OpenPanguVisionTransformerPretrainedModel._from_config(config.vision_config)
930
+ self.language_model = OpenPanguVLTextModel(config.text_config)
931
+
932
+ self.rope_deltas = None # cache rope_deltas here
933
+
934
+ self.visual.vision_projection = ProjectionSingle(config.vision_config.out_hidden_size, config.hidden_size)
935
+
936
+ # Initialize weights and apply final processing
937
+ self.post_init()
938
+ self._parse_preprocess_params(self.config.vision_config)
939
+
940
+ def _parse_preprocess_params(self, vision_config):
941
+ self.channel = vision_config.in_channels
942
+ self.patch_size = vision_config.patch_size
943
+ from transformers import AutoProcessor
944
+ processor = AutoProcessor.from_pretrained(self.config.name_or_path, trust_remote_code=True, local_files_only=True)
945
+ self.do_rescale = processor.image_processor.do_rescale
946
+ self.rescale_factor = processor.image_processor.rescale_factor
947
+ self.do_normalize = processor.image_processor.do_normalize
948
+ self.image_mean = tuple(processor.image_processor.image_mean)
949
+ self.image_std = tuple(processor.image_processor.image_std)
950
+
951
+ def get_input_embeddings(self):
952
+ return self.language_model.get_input_embeddings()
953
+
954
+ def set_input_embeddings(self, value):
955
+ self.language_model.set_input_embeddings(value)
956
+
957
+ def set_decoder(self, decoder):
958
+ self.language_model = decoder
959
+
960
+ def get_decoder(self):
961
+ return self.language_model
962
+
963
+ def get_rope_index(
964
+ self,
965
+ input_ids: Optional[torch.LongTensor] = None,
966
+ image_grid_thw: Optional[torch.LongTensor] = None,
967
+ video_grid_thw: Optional[torch.LongTensor] = None,
968
+ second_per_grid_ts: Optional[torch.Tensor] = None,
969
+ attention_mask: Optional[torch.Tensor] = None,
970
+ ) -> tuple[torch.Tensor, torch.Tensor]:
971
+ """
972
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
973
+
974
+ Explanation:
975
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
976
+
977
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
978
+ Examples:
979
+ input_ids: [T T T T T], here T is for text.
980
+ temporal position_ids: [0, 1, 2, 3, 4]
981
+ height position_ids: [0, 1, 2, 3, 4]
982
+ width position_ids: [0, 1, 2, 3, 4]
983
+
984
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
985
+ and 1D rotary position embedding for text part.
986
+ Examples:
987
+ Temporal (Time): 3 patches, representing different segments of the video in time.
988
+ Height: 2 patches, dividing each frame vertically.
989
+ Width: 2 patches, dividing each frame horizontally.
990
+ We also have some important parameters:
991
+ fps (Frames Per Second): The video's frame rate, set to 1.
992
+ This means one frame is processed each second.
993
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens"
994
+ are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens
995
+ per second. So each second of the video will be represented with 25 separate time points.
996
+ It essentially defines the temporal granularity.
997
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
998
+ interval: The step size for the temporal position IDs, calculated as
999
+ tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50.
1000
+ This means that consecutive temporal patches differ by 50 in their temporal position IDs.
1001
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
1002
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
1003
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
1004
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
1005
+ text temporal position_ids: [101, 102, 103, 104, 105]
1006
+ text height position_ids: [101, 102, 103, 104, 105]
1007
+ text width position_ids: [101, 102, 103, 104, 105]
1008
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
1009
+
1010
+ Args:
1011
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1012
+ Indices of input sequence tokens in the vocabulary.
1013
+ Padding will be ignored by default should you provide it.
1014
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1015
+ The temporal, height and width of feature shape of each image in LLM.
1016
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1017
+ The temporal, height and width of feature shape of each video in LLM.
1018
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
1019
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
1020
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1021
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1022
+
1023
+ - 1 for tokens that are **not masked**,
1024
+ - 0 for tokens that are **masked**.
1025
+
1026
+ Returns:
1027
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
1028
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
1029
+ """
1030
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
1031
+ image_token_id = self.config.image_token_id
1032
+ video_token_id = self.config.video_token_id
1033
+ vision_start_token_id = self.config.vision_start_token_id
1034
+ vision_end_token_id = self.config.vision_end_token_id
1035
+ tokens_per_second = getattr(self.config, "tokens_per_second", 1.0)
1036
+ mrope_position_deltas = []
1037
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
1038
+ total_input_ids = input_ids
1039
+ if attention_mask is None:
1040
+ attention_mask = torch.ones_like(total_input_ids)
1041
+ position_ids = torch.ones(
1042
+ 3,
1043
+ input_ids.shape[0],
1044
+ input_ids.shape[1],
1045
+ dtype=input_ids.dtype,
1046
+ device=input_ids.device,
1047
+ )
1048
+ attention_mask = attention_mask.to(total_input_ids.device)
1049
+ for i, input_ids in enumerate(total_input_ids):
1050
+ input_ids = input_ids[attention_mask[i] == 1]
1051
+ input_tokens = input_ids.tolist()
1052
+ src_item = input_tokens
1053
+ video_idx = 0
1054
+ image_idx = 0
1055
+ new_src_item: list[int] = []
1056
+ llm_pos_ids_list: list[torch.Tensor] = []
1057
+
1058
+ idx = 0
1059
+ while idx < len(src_item):
1060
+ new_src_item_len = len(new_src_item)
1061
+ start_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1062
+ if src_item[idx] not in [video_token_id, image_token_id]:
1063
+ new_src_item.append(src_item[idx])
1064
+ llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
1065
+ llm_pos_ids_list.append(llm_pos_ids.to(position_ids.device))
1066
+ elif src_item[idx] == image_token_id:
1067
+ grid_t = image_grid_thw[image_idx][0]
1068
+ grid_hs = image_grid_thw[:, 1]
1069
+ grid_ws = image_grid_thw[:, 2]
1070
+ t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
1071
+ llm_pos_ids = self._get_llm_pos_ids_for_vision(
1072
+ start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
1073
+ )
1074
+ llm_pos_ids_list.append(llm_pos_ids.to(position_ids.device))
1075
+ vision_seqlen = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
1076
+ new_src_item.extend([image_token_id] * vision_seqlen)
1077
+ image_idx += 1
1078
+ else:
1079
+ # src_item[idx] == video_token_id
1080
+ # Get the grid information of the current video
1081
+ T = video_grid_thw[video_idx][0].item()
1082
+ H = video_grid_thw[video_idx][1].item()
1083
+ W = video_grid_thw[video_idx][2].item()
1084
+ llm_H = H // spatial_merge_size
1085
+ llm_W = W // spatial_merge_size
1086
+ tokens_per_frame = llm_H * llm_W
1087
+ # Get timestamps (one t value per frame)
1088
+ t_index_all = (torch.arange(T)).long()
1089
+ # Calculate the current starting position
1090
+ start_pos = llm_pos_ids_list[-1].max().item() + 1 if llm_pos_ids_list else 0
1091
+ current_pos = start_pos
1092
+ # frame by frame processing
1093
+ final_frame_time = T - 1 # Index of the last frame
1094
+ for t in range(T):
1095
+ # 1. Add the left placeholder (vision start token) position; skipped for the first frame
1096
+ if t != 0:
1097
+ new_src_item.append(vision_start_token_id) # For looping, count
1098
+ bot_pos = torch.full((3, 1), current_pos, dtype=torch.long)
1099
+ llm_pos_ids_list.append(bot_pos.to(position_ids.device))
1100
+ current_pos += 1
1101
+ # 2. Video tokens for frame t
1102
+ # Construct a single frame of (t, h, w)
1103
+ grid_h = torch.arange(llm_H).view(-1, 1).expand(-1, llm_W).flatten()
1104
+ grid_w = torch.arange(llm_W).view(1, -1).expand(llm_H, -1).flatten()
1105
+ # Here we don't add current_pos to h/w, just keep the original (t, h, w)
1106
+ frame_pos = torch.stack(
1107
+ [
1108
+ torch.full_like(grid_h, 0, dtype=torch.long), # t
1109
+ grid_h, # h
1110
+ grid_w # w
1111
+ ]
1112
+ ) # shape: (3, tokens_per_frame)
1113
+ frame_pos_with_offset = frame_pos + current_pos # Current frame position offset
1114
+ new_src_item.extend([video_token_id] * tokens_per_frame) # For looping, count
1115
+ llm_pos_ids_list.append(frame_pos_with_offset.to(position_ids.device))
1116
+ current_pos += max(llm_H, llm_W)
1117
+ # 3. Add the right placeholder (vision end token) position; skipped for the last frame
1118
+ if t != final_frame_time:
1119
+ new_src_item.append(vision_end_token_id) # For looping, count
1120
+ eot_pos = torch.full((3, 1), current_pos, dtype=torch.long)
1121
+ llm_pos_ids_list.append(eot_pos.to(position_ids.device))
1122
+ current_pos += 1
1123
+ video_idx += 1
1124
+ # move to the next token
1125
+ idx += len(new_src_item) - new_src_item_len
1126
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1127
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1128
+ mrope_position_delta = llm_positions.max() + 1 - len(src_item)
1129
+ mrope_position_deltas.append(mrope_position_delta)
1130
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1131
+ return position_ids, mrope_position_deltas
1132
+ else:
1133
+ if attention_mask is not None:
1134
+ position_ids = attention_mask.long().cumsum(-1) - 1
1135
+ position_ids.masked_fill_(attention_mask == 0, 1)
1136
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1137
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1138
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1139
+ else:
1140
+ position_ids = (
1141
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1142
+ .view(1, 1, -1)
1143
+ .expand(3, input_ids.shape[0], -1)
1144
+ )
1145
+ mrope_position_deltas = torch.zeros(
1146
+ [input_ids.shape[0], 1],
1147
+ device=input_ids.device,
1148
+ dtype=input_ids.dtype,
1149
+ )
1150
+
1151
+ return position_ids, mrope_position_deltas
1152
+
1153
+ def _get_llm_pos_ids_for_vision(
1154
+ self,
1155
+ start_idx: int,
1156
+ vision_idx: int,
1157
+ spatial_merge_size: int,
1158
+ t_index: list[int],
1159
+ grid_hs: torch.Tensor,
1160
+ grid_ws: torch.Tensor,
1161
+ ) -> torch.Tensor:
1162
+ llm_pos_ids_list = []
1163
+ llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
1164
+ llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
1165
+ h_index = (
1166
+ torch.arange(llm_grid_h)
1167
+ .to(llm_grid_h.device)
1168
+ .view(1, -1, 1)
1169
+ .expand(len(t_index), -1, llm_grid_w)
1170
+ .flatten()
1171
+ )
1172
+ w_index = (
1173
+ torch.arange(llm_grid_w)
1174
+ .to(llm_grid_h.device)
1175
+ .view(1, 1, -1)
1176
+ .expand(len(t_index), llm_grid_h, -1)
1177
+ .flatten()
1178
+ )
1179
+ t_index_tensor = (
1180
+ torch.Tensor(t_index)
1181
+ .to(llm_grid_h.device)
1182
+ .view(-1, 1)
1183
+ .expand(-1, llm_grid_h * llm_grid_w)
1184
+ .long()
1185
+ .flatten()
1186
+ )
1187
+ _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
1188
+ llm_pos_ids_list.append(_llm_pos_ids + start_idx)
1189
+ llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
1190
+ return llm_pos_ids
1191
+
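As a reading aid for the image branch of `get_rope_index` and `_get_llm_pos_ids_for_vision`, here is a tensor-free sketch of the position ids for a sequence made of a text prefix, one single-frame image, and a text suffix. The function name `rope_positions` and the example sizes are hypothetical; the sketch assumes `grid_t == 1` and does not cover the video branch.

```python
def rope_positions(n_prefix, grid_h, grid_w, n_suffix, spatial_merge_size=2):
    """Return (t, h, w) position lists for: text prefix, one image, text suffix."""
    llm_h = grid_h // spatial_merge_size
    llm_w = grid_w // spatial_merge_size
    # Text tokens use the same index on all three axes.
    t = list(range(n_prefix))
    h = list(range(n_prefix))
    w = list(range(n_prefix))
    start = n_prefix  # largest position so far plus one (0 when there is no prefix)
    # Image tokens: constant t, row index on h, column index on w, all offset by start.
    t += [start] * (llm_h * llm_w)
    h += [start + r for r in range(llm_h) for _ in range(llm_w)]
    w += [start + c for _ in range(llm_h) for c in range(llm_w)]
    # Following text resumes at the largest position used so far plus one.
    nxt = max(t + h + w) + 1
    tail = list(range(nxt, nxt + n_suffix))
    return t + tail, h + tail, w + tail


t, h, w = rope_positions(n_prefix=3, grid_h=4, grid_w=4, n_suffix=2)
print(t)  # [0, 1, 2, 3, 3, 3, 3, 5, 6]
print(h)  # [0, 1, 2, 3, 3, 4, 4, 5, 6]
print(w)  # [0, 1, 2, 3, 4, 3, 4, 5, 6]
print(max(t + h + w) + 1 - len(t))  # -2, the analogue of mrope_position_delta
```

The negative delta reflects that a 2x2 block of image tokens advances the position counter by only 2, while it occupies 4 slots in the sequence.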
1192
+ def get_video_features(
1193
+ self,
1194
+ pixel_values_videos: torch.FloatTensor,
1195
+ video_grid_thw: Optional[torch.LongTensor] = None,
1196
+ ):
1197
+ """
1198
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
1199
+
1200
+ Args:
1201
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1202
+ The tensors corresponding to the input videos.
1203
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1204
+ The temporal, height and width of feature shape of each video in LLM.
1205
+ """
1206
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
1207
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
1208
+
1209
+ video_embeds = self.visual.vision_projection(video_embeds)
1210
+
1211
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1212
+ video_embeds = torch.split(video_embeds, split_sizes)
1213
+ return video_embeds
1214
+
1215
+ def get_image_features(
1216
+ self,
1217
+ pixel_values: torch.FloatTensor,
1218
+ image_grid_thw: Optional[torch.LongTensor] = None,
1219
+ ):
1220
+ """
1221
+ Encodes images into continuous embeddings that can be forwarded to the language model.
1222
+
1223
+ Args:
1224
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1225
+ The tensors corresponding to the input images.
1226
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1227
+ The temporal, height and width of feature shape of each image in LLM.
1228
+ """
1229
+ pixel_values = pixel_values.type(self.visual.dtype)
1230
+ # rescale and normalize
1231
+ pixel_values = pixel_values.reshape(-1, self.channel, self.patch_size, self.patch_size)
1232
+ pixel_values = rescale_and_normalize(pixel_values, self.do_rescale, self.rescale_factor, self.do_normalize,
1233
+ self.image_mean, self.image_std)
1234
+ pixel_values = pixel_values.reshape(-1, self.channel * self.patch_size * self.patch_size)
1235
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
1236
+
1237
+ image_embeds = self.visual.vision_projection(image_embeds)
1238
+
1239
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1240
+ image_embeds = torch.split(image_embeds, split_sizes)
1241
+ return image_embeds
1242
+
1243
+ @auto_docstring
1244
+ def forward(
1245
+ self,
1246
+ input_ids: torch.LongTensor = None,
1247
+ attention_mask: Optional[torch.Tensor] = None,
1248
+ position_ids: Optional[torch.LongTensor] = None,
1249
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
1250
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1251
+ use_cache: Optional[bool] = None,
1252
+ output_attentions: Optional[bool] = None,
1253
+ output_hidden_states: Optional[bool] = None,
1254
+ return_dict: Optional[bool] = None,
1255
+ pixel_values: Optional[torch.Tensor] = None,
1256
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1257
+ image_grid_thw: Optional[torch.LongTensor] = None,
1258
+ video_grid_thw: Optional[torch.LongTensor] = None,
1259
+ rope_deltas: Optional[torch.LongTensor] = None,
1260
+ cache_position: Optional[torch.LongTensor] = None,
1261
+ second_per_grid_ts: Optional[torch.Tensor] = None,
1262
+ **kwargs: Unpack[KwargsForCausalLM],
1263
+ ) -> Union[tuple, OpenPanguVLModelOutputWithPast]:
1264
+ r"""
1265
+ pixel_values_videos (`torch.FloatTensor` of shape `(seq_length,
1266
+ num_channels * temporal_size * image_size * image_size)`):
1267
+ The tensors corresponding to the input videos. Pixel values can be obtained using
1268
+ [`AutoImageProcessor`]. See [`OpenPanguVLImageProcessor.__call__`] for details. [`OpenPanguVLProcessor`] uses
1269
+ [`OpenPanguVLImageProcessor`] for processing videos.
1270
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1271
+ The temporal, height and width of feature shape of each image in LLM.
1272
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1273
+ The temporal, height and width of feature shape of each video in LLM.
1274
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1275
+ The offset between the multimodal RoPE position index and the plain sequence position, reused to build position ids during decoding.
1276
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
1277
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
1278
+ """
1279
+
1280
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1281
+ output_hidden_states = (
1282
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1283
+ )
1284
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1285
+
1286
+ if inputs_embeds is None:
1287
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1288
+ if pixel_values is not None:
1289
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
1290
+ image_embeds = torch.cat(image_embeds, dim=0)
1291
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
1292
+ n_image_features = image_embeds.shape[0]
1293
+ if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
1294
+ raise ValueError(
1295
+ "Image features and image tokens do not match: "
1296
+ f"tokens: {n_image_tokens}, features {n_image_features}"
1297
+ )
1298
+
1299
+ mask = input_ids == self.config.image_token_id
1300
+ mask_unsqueezed = mask.unsqueeze(-1)
1301
+ mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
1302
+ image_mask = mask_expanded.to(inputs_embeds.device)
1303
+
1304
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
1305
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1306
+
1307
+ if pixel_values_videos is not None:
1308
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
1309
+ video_embeds = torch.cat(video_embeds, dim=0)
1310
+ n_video_tokens = (input_ids == self.config.video_token_id).sum()
1311
+ n_video_features = video_embeds.shape[0]
1312
+ if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
1313
+ raise ValueError(
1314
+ "Video features and video tokens do not match: "
1315
+ f"tokens: {n_video_tokens}, features {n_video_features}"
1316
+ )
1317
+
1318
+ mask = input_ids == self.config.video_token_id
1319
+ mask_unsqueezed = mask.unsqueeze(-1)
1320
+ mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
1321
+ video_mask = mask_expanded.to(inputs_embeds.device)
1322
+
1323
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
1324
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1325
+
1326
+ if position_ids is None:
1327
+ attention_mask_tensor = (
1328
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
1329
+ )
1330
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
1331
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
1332
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
1333
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
1334
+
1335
+ # Calculate RoPE index once per generation in the pre-fill stage only.
1336
+ # When compiling, we can't check tensor values thus we check only input length
1337
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
1338
+ # models currently cannot do assisted decoding
1339
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
1340
+ (input_ids is not None and input_ids.shape[1] != 1)
1341
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
1342
+ )
1343
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
1344
+ (cache_position is not None and cache_position[0] == 0)
1345
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
1346
+ )
1347
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
1348
+ position_ids, rope_deltas = self.get_rope_index(
1349
+ input_ids,
1350
+ image_grid_thw,
1351
+ video_grid_thw,
1352
+ second_per_grid_ts=second_per_grid_ts,
1353
+ attention_mask=attention_mask_tensor,
1354
+ )
1355
+ self.rope_deltas = rope_deltas
1356
+ # Otherwise, reuse the previously computed rope deltas to derive the correct position ids.
1357
+ else:
1358
+ batch_size, seq_length, _ = inputs_embeds.shape
1359
+ delta = (
1360
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1361
+ if cache_position is not None
1362
+ else 0
1363
+ )
1364
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1365
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
1366
+ if cache_position is not None: # otherwise `delta` is the int `0`
1367
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1368
+ position_ids = position_ids.add(delta)
1369
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1370
+
1371
+ outputs = self.language_model(
1372
+ input_ids=None,
1373
+ position_ids=position_ids,
1374
+ attention_mask=attention_mask,
1375
+ past_key_values=past_key_values,
1376
+ inputs_embeds=inputs_embeds,
1377
+ use_cache=use_cache,
1378
+ output_attentions=output_attentions,
1379
+ output_hidden_states=output_hidden_states,
1380
+ return_dict=True,
1381
+ cache_position=cache_position,
1382
+ **kwargs,
1383
+ )
1384
+
1385
+ output = OpenPanguVLModelOutputWithPast(
1386
+ last_hidden_state=outputs.last_hidden_state,
1387
+ past_key_values=outputs.past_key_values,
1388
+ hidden_states=outputs.hidden_states,
1389
+ attentions=outputs.attentions,
1390
+ rope_deltas=self.rope_deltas,
1391
+ )
1392
+ return output if return_dict else output.to_tuple()
1393
+
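The `split_sizes` computation in `get_image_features` and `get_video_features` above divides each item's patch count (`t * h * w`) by the square of the spatial merge size, then splits the concatenated embeddings back into one tensor per image or video. A minimal sketch with plain tensors, assuming a merge size of 2 and a hidden size of 8 (illustrative values, not read from the model config):

```python
import torch

# Illustrative values only; the real ones come from the vision config.
spatial_merge_size = 2
hidden_size = 8

# Two images whose patch grids are (t, h, w) = (1, 4, 4) and (1, 2, 6).
image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 6]])

# Each image has t*h*w patches, merged in groups of spatial_merge_size**2.
split_sizes = (image_grid_thw.prod(-1) // spatial_merge_size**2).tolist()

# Stand-in for the concatenated output of the vision tower and projection.
image_embeds = torch.randn(sum(split_sizes), hidden_size)

# One embedding tensor per image, in input order.
per_image = torch.split(image_embeds, split_sizes)
print(split_sizes, [tuple(p.shape) for p in per_image])
```

The number of rows per image is also the number of placeholder tokens that image is expected to occupy in `input_ids`.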
1394
+
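The placeholder replacement in `OpenPanguVLModel.forward` above relies on `masked_scatter`, which copies rows of the vision embeddings, in order, into the positions where `input_ids` equals the image token id. A small sketch of that step in isolation, assuming a made-up token id of 99 and a hidden size of 4 (both hypothetical):

```python
import torch

image_token_id = 99  # hypothetical placeholder id, not the real config value
hidden = 4

# Batch of two sequences containing three image placeholders in total.
input_ids = torch.tensor([[1, 99, 99, 2], [99, 3, 4, 5]])
inputs_embeds = torch.zeros(2, 4, hidden)

# Stand-in for the concatenated vision features: one row per placeholder.
image_embeds = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden) + 1.0

# Same consistency check as in the model: token count must match feature count.
n_image_tokens = int((input_ids == image_token_id).sum())
if n_image_tokens != image_embeds.shape[0]:
    raise ValueError(
        f"Image features and image tokens do not match: "
        f"tokens: {n_image_tokens}, features {image_embeds.shape[0]}"
    )

# Broadcast the boolean token mask over the hidden dimension, then scatter.
image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(image_mask, image_embeds)
```

Because the mask is traversed in row-major order, the features must be concatenated in the same order in which their placeholders appear across the batch.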
1395
+ @dataclass
1396
+ @auto_docstring(
1397
+ custom_intro="""
1398
+ Base class for OpenPanguVL causal language model (or autoregressive) outputs.
1399
+ """
1400
+ )
1401
+ class OpenPanguVLCausalLMOutputWithPast(ModelOutput):
1402
+ r"""
1403
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1404
+ Language modeling loss (for next-token prediction).
1405
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1406
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1407
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1408
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1409
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
1410
+
1411
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1412
+ `past_key_values` input) to speed up sequential decoding.
1413
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1414
+ The offset between the multimodal RoPE position index and the plain sequence position, reused to build position ids during decoding.
1415
+ """
1416
+
1417
+ loss: Optional[torch.FloatTensor] = None
1418
+ logits: Optional[torch.FloatTensor] = None
1419
+ past_key_values: Optional[list[torch.FloatTensor]] = None
1420
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
1421
+ attentions: Optional[tuple[torch.FloatTensor]] = None
1422
+ rope_deltas: Optional[torch.LongTensor] = None
1423
+
1424
+
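The `rope_deltas` field documented above is what the decode branch of `OpenPanguVLModel.forward` uses to rebuild position ids without recomputing the multimodal RoPE index. The sketch below mirrors that branch as a standalone function; the function name is hypothetical, and it assumes `rope_deltas` has shape `(batch_size, 1)` so that it broadcasts over the sequence dimension:

```python
import torch


def decode_position_ids(rope_deltas, cache_position, batch_size, seq_length):
    """Rebuild 3D position ids for a decode step from cached rope deltas."""
    # Shift every sample by the current cache offset plus its own delta.
    delta = cache_position[0] + rope_deltas
    position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
    # Handle expanded batches (e.g. beam search) by repeating the deltas.
    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
    position_ids = position_ids.add(delta)
    # The same ids are shared by the temporal, height and width axes.
    return position_ids.unsqueeze(0).expand(3, -1, -1)


rope_deltas = torch.tensor([[-3], [0]])  # assumed shape (batch_size, 1)
cache_position = torch.tensor([10])      # 10 tokens already in the cache
pos = decode_position_ids(rope_deltas, cache_position, batch_size=2, seq_length=1)
print(pos[0].tolist())
```

A negative delta reflects that a block of vision tokens advances the position index by less than its token count, so text generated after an image continues from a smaller position than the raw sequence length.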
1425
+ class OpenPanguVL(OpenPanguPreTrainedModel, GenerationMixin):
1426
+ _checkpoint_conversion_mapping = {
1427
+ "^visual": "model.visual",
1428
+ r"^model(?!\.(language_model|visual))": "model.language_model",
1429
+ }
1430
+ _tied_weights_keys = ["lm_head.weight"]
1431
+
1432
+ def __init__(self, config):
1433
+ super().__init__(config)
1434
+ self.model = OpenPanguVLModel(config)
1435
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1436
+
1437
+ self.post_init()
1438
+
1439
+ def get_input_embeddings(self):
1440
+ return self.model.get_input_embeddings()
1441
+
1442
+ def set_input_embeddings(self, value):
1443
+ self.model.set_input_embeddings(value)
1444
+
1445
+ def get_output_embeddings(self):
1446
+ return self.lm_head
1447
+
1448
+ def set_output_embeddings(self, new_embeddings):
1449
+ self.lm_head = new_embeddings
1450
+
1451
+ def set_decoder(self, decoder):
1452
+ self.model.set_decoder(decoder)
1453
+
1454
+ def get_decoder(self):
1455
+ return self.model.get_decoder()
1456
+
1457
+ def get_video_features(
1458
+ self,
1459
+ pixel_values_videos: torch.FloatTensor,
1460
+ video_grid_thw: Optional[torch.LongTensor] = None,
1461
+ ):
1462
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
1463
+
1464
+ def get_image_features(
1465
+ self,
1466
+ pixel_values: torch.FloatTensor,
1467
+ image_grid_thw: Optional[torch.LongTensor] = None,
1468
+ ):
1469
+ return self.model.get_image_features(pixel_values, image_grid_thw)
1470
+
1471
+ # Make modules available through the conditional class for backward compatibility
1472
+ @property
1473
+ def language_model(self):
1474
+ return self.model.language_model
1475
+
1476
+ @property
1477
+ def visual(self):
1478
+ return self.model.visual
1479
+
1480
+ @can_return_tuple
1481
+ @auto_docstring
1482
+ def forward(
1483
+ self,
1484
+ input_ids: torch.LongTensor = None,
1485
+ attention_mask: Optional[torch.Tensor] = None,
1486
+ position_ids: Optional[torch.LongTensor] = None,
1487
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
1488
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1489
+ labels: Optional[torch.LongTensor] = None,
1490
+ use_cache: Optional[bool] = None,
1491
+ output_attentions: Optional[bool] = None,
1492
+ output_hidden_states: Optional[bool] = None,
1493
+ pixel_values: Optional[torch.Tensor] = None,
1494
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1495
+ image_grid_thw: Optional[torch.LongTensor] = None,
1496
+ video_grid_thw: Optional[torch.LongTensor] = None,
1497
+ rope_deltas: Optional[torch.LongTensor] = None,
1498
+ cache_position: Optional[torch.LongTensor] = None,
1499
+ second_per_grid_ts: Optional[torch.Tensor] = None,
1500
+ **kwargs: Unpack[KwargsForCausalLM],
1501
+ ) -> Union[tuple, OpenPanguVLCausalLMOutputWithPast]:
1502
+ r"""
1503
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1504
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1505
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1506
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1507
+ pixel_values_videos (`torch.FloatTensor` of shape `(seq_length,
1508
+ num_channels * temporal_size * image_size * image_size)`):
1509
+ The tensors corresponding to the input videos. Pixel values can be obtained using
1510
+ [`AutoImageProcessor`]. See [`OpenPanguVLImageProcessor.__call__`] for details. [`OpenPanguVLProcessor`] uses
1511
+ [`OpenPanguVLImageProcessor`] for processing videos.
1512
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1513
+ The temporal, height and width of feature shape of each image in LLM.
1514
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1515
+ The temporal, height and width of feature shape of each video in LLM.
1516
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1517
+ The offset between the multimodal RoPE position index and the plain sequence position, reused to build position ids during decoding.
1518
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
1519
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
1520
+
1521
+ Example:
1522
+
1523
+ ```python
1524
+ >>> from PIL import Image
1525
+ >>> import requests
1526
+ >>> from transformers import AutoProcessor, OpenPanguVLForConditionalGeneration
1527
+
1528
+ >>> model = OpenPanguVLForConditionalGeneration.from_pretrained("Pangu/Pangu_7B_V5_VL_HF_vllm_ascend")
1529
+ >>> processor = AutoProcessor.from_pretrained("Pangu/Pangu_7B_V5_VL_HF_vllm_ascend")
1530
+
1531
+ >>> messages = [
1532
+ {
1533
+ "role": "user",
1534
+ "content": [
1535
+ {"type": "image"},
1536
+ {"type": "text", "text": "What is shown in this image?"},
1537
+ ],
1538
+ },
1539
+ ]
1540
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1541
+
1542
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1543
+ >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
1544
+
1545
+ >>> # Generate
1546
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1547
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1548
+ "The image shows a street scene with a red stop sign in the foreground.
1549
+ In the background, there is a large red gate with Chinese characters ..."
1550
+ ```"""
1551
+
1552
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1553
+ output_hidden_states = (
1554
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1555
+ )
1556
+
1557
+ outputs = self.model(
1558
+ input_ids=input_ids,
1559
+ pixel_values=pixel_values,
1560
+ pixel_values_videos=pixel_values_videos,
1561
+ image_grid_thw=image_grid_thw,
1562
+ video_grid_thw=video_grid_thw,
1563
+ second_per_grid_ts=second_per_grid_ts,
1564
+ position_ids=position_ids,
1565
+ attention_mask=attention_mask,
1566
+ past_key_values=past_key_values,
1567
+ inputs_embeds=inputs_embeds,
1568
+ use_cache=use_cache,
1569
+ output_attentions=output_attentions,
1570
+ output_hidden_states=output_hidden_states,
1571
+ return_dict=True,
1572
+ cache_position=cache_position,
1573
+ **kwargs,
1574
+ )
1575
+
1576
+ hidden_states = outputs[0]
1577
+ logits = self.lm_head(hidden_states)
1578
+
1579
+ loss = None
1580
+ if labels is not None:
1581
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
1582
+
1583
+ return OpenPanguVLCausalLMOutputWithPast(
1584
+ loss=loss,
1585
+ logits=logits,
1586
+ past_key_values=outputs.past_key_values,
1587
+ hidden_states=outputs.hidden_states,
1588
+ attentions=outputs.attentions,
1589
+ rope_deltas=outputs.rope_deltas,
1590
+ )
1591
+
1592
+ def prepare_inputs_for_generation(
1593
+ self,
1594
+ input_ids,
1595
+ past_key_values=None,
1596
+ attention_mask=None,
1597
+ inputs_embeds=None,
1598
+ cache_position=None,
1599
+ position_ids=None,
1600
+ use_cache=True,
1601
+ pixel_values=None,
1602
+ pixel_values_videos=None,
1603
+ image_grid_thw=None,
1604
+ video_grid_thw=None,
1605
+ second_per_grid_ts=None,
1606
+ **kwargs,
1607
+ ):
1608
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1609
+
1610
+ model_inputs = super().prepare_inputs_for_generation(
1611
+ input_ids,
1612
+ past_key_values=past_key_values,
1613
+ attention_mask=attention_mask,
1614
+ inputs_embeds=inputs_embeds,
1615
+ cache_position=cache_position,
1616
+ position_ids=position_ids,
1617
+ pixel_values=pixel_values,
1618
+ pixel_values_videos=pixel_values_videos,
1619
+ image_grid_thw=image_grid_thw,
1620
+ video_grid_thw=video_grid_thw,
1621
+ second_per_grid_ts=second_per_grid_ts,
1622
+ use_cache=use_cache,
1623
+ **kwargs,
1624
+ )
1625
+
1626
+ # OpenPangu-VL position_ids are prepared with rope_deltas in forward
1627
+ model_inputs["position_ids"] = None
1628
+
1629
+ if cache_position[0] != 0:
1630
+ model_inputs["pixel_values"] = None
1631
+ model_inputs["pixel_values_videos"] = None
1632
+
1633
+ return model_inputs
1634
+
1635
+ def _get_image_nums_and_video_nums(
1636
+ self,
1637
+ input_ids: Optional[torch.LongTensor],
1638
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1639
+ """
1640
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
1641
+ These parameters are not passed through the processor to avoid unpredictable impacts
1642
+ from interface modifications.
1643
+
1644
+ Args:
1645
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1646
+ Indices of input sequence tokens in the vocabulary.
1647
+
1648
+ Returns:
1649
+ image_nums (`torch.LongTensor` of shape `(batch_size,)`): Number of images in each sample.
1649
+ video_nums (`torch.LongTensor` of shape `(batch_size,)`): Number of videos in each sample.
1651
+ """
1652
+ image_token_id = self.config.image_token_id
1653
+ video_token_id = self.config.video_token_id
1654
+ vision_start_token_id = self.config.vision_start_token_id
1655
+
1656
+ vision_start_mask = input_ids == vision_start_token_id
1657
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
1658
+ image_mask = input_ids == image_token_id
1659
+ video_mask = input_ids == video_token_id
1660
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
1661
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
1662
+
1663
+ return image_nums, video_nums
1664
+
1665
+ def _expand_inputs_for_generation(
1666
+ self,
1667
+ expand_size: int = 1,
1668
+ is_encoder_decoder: bool = False,
1669
+ input_ids: Optional[torch.LongTensor] = None,
1670
+ **model_kwargs,
1671
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
1672
+ # Overwritten -- Support for expanding tensors without a batch size dimension
1673
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_ts
1674
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
1675
+ # image_grid_thw.shape[0] is sum(num_images for samples)
1676
+
1677
+ if expand_size == 1:
1678
+ return input_ids, model_kwargs
1679
+
1680
+ visual_keys = [
1681
+ "pixel_values",
1682
+ "image_grid_thw",
1683
+ "pixel_values_videos",
1684
+ "video_grid_thw",
1685
+ "second_per_grid_ts",
1686
+ ]
1687
+
1688
+ def _expand_dict_for_generation_visual(dict_to_expand):
1689
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
1690
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
1691
+ image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
1692
+
1693
+ def _repeat_interleave_samples(x, lengths, repeat_times):
1694
+ samples = torch.split(x, lengths)
1695
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1696
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1697
+ return result
1698
+
1699
+ for key in dict_to_expand:
1700
+ if key == "pixel_values":
1701
+ # split images into samples
1702
+ samples = torch.split(image_grid_thw, list(image_nums))
1703
+ # compute the sequence length of images for each sample
1704
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1705
+ dict_to_expand[key] = _repeat_interleave_samples(
1706
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1707
+ )
1708
+ elif key == "image_grid_thw":
1709
+ # get the num of images for each sample
1710
+ lengths = list(image_nums)
1711
+ dict_to_expand[key] = _repeat_interleave_samples(
1712
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1713
+ )
1714
+ elif key == "pixel_values_videos":
1715
+ samples = torch.split(video_grid_thw, list(video_nums))
1716
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1717
+ dict_to_expand[key] = _repeat_interleave_samples(
1718
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1719
+ )
1720
+ elif key == "video_grid_thw":
1721
+ lengths = list(video_nums)
1722
+ dict_to_expand[key] = _repeat_interleave_samples(
1723
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1724
+ )
1725
+ elif key == "second_per_grid_ts":
1726
+ if not isinstance(dict_to_expand[key], list):
1727
+ raise TypeError(
1728
+ f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
1729
+ )
1730
+ tensor = torch.tensor(dict_to_expand[key])
1731
+ lengths = list(video_nums)
1732
+ tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
1733
+ dict_to_expand[key] = tensor.tolist()
1734
+ return dict_to_expand
1735
+
1736
+ def _expand_dict_for_generation(dict_to_expand):
1737
+ for key in dict_to_expand:
1738
+ if key != "cache_position":
1739
+ if (
1740
+ dict_to_expand[key] is not None
1741
+ and isinstance(dict_to_expand[key], torch.Tensor)
1742
+ and key not in visual_keys
1743
+ ):
1744
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1745
+ return dict_to_expand
1746
+
1747
+ # input_ids is required for expanding visual inputs
1748
+ # If input_ids is unavailable, visual inputs will not be used;
1749
+ # therefore, there is no need to expand visual inputs.
1750
+ if input_ids is not None and input_ids.numel() != 0:
1751
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1752
+
1753
+ if input_ids is not None:
1754
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1755
+
1756
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
1757
+
1758
+ if is_encoder_decoder:
1759
+ if model_kwargs.get("encoder_outputs") is None:
1760
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1761
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1762
+
1763
+ return input_ids, model_kwargs
1764
+
1765
+
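`_get_image_nums_and_video_nums` above counts images and videos per sample by shifting the vision-start mask one position to the right and checking which token follows each start marker. A self-contained sketch of that trick, using made-up token ids (100, 101, 102 are hypothetical, not the real config values):

```python
import torch

# Hypothetical ids for illustration only.
vision_start_token_id, image_token_id, video_token_id = 100, 101, 102

input_ids = torch.tensor([
    [100, 101, 101, 7, 100, 101, 8, 0],    # two images
    [5, 100, 102, 102, 102, 6, 100, 101],  # one video, then one image
])

# Mark the position right after every vision-start token.
vision_start_mask = input_ids == vision_start_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)

# An item is an image (or video) if its first token is the matching placeholder.
image_nums = torch.sum(vision_first_mask & (input_ids == image_token_id), dim=1)
video_nums = torch.sum(vision_first_mask & (input_ids == video_token_id), dim=1)
print(image_nums.tolist(), video_nums.tolist())
```

Only the first placeholder after each start marker is counted, so an image that expands to many placeholder tokens still counts once. Note that `torch.roll` wraps around, so a start marker in the last position would be shifted to index 0.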
1766
+ __all__ = ["OpenPanguVL", "OpenPanguVLModel", "OpenPanguPreTrainedModel"]
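The visual inputs handled in `_expand_inputs_for_generation` have no batch dimension, so they cannot be expanded with a plain `repeat_interleave`. The nested helper splits them per sample and repeats each sample's block as a unit. The sketch below restates that helper at module level with illustrative inputs (the grid values and per-sample counts are made up):

```python
import torch


def repeat_interleave_samples(x, lengths, repeat_times):
    """Repeat each per-sample block of `x` `repeat_times` times, keeping blocks contiguous."""
    samples = torch.split(x, lengths)
    # Repeat along dim 0 only; leave all other dims untouched.
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)


# Three images in total: sample 0 owns the first two, sample 1 owns the last.
image_grid_thw = torch.tensor([[1, 2, 2], [1, 4, 2], [1, 2, 4]])
image_nums = [2, 1]

expanded = repeat_interleave_samples(image_grid_thw, image_nums, repeat_times=2)
print(expanded.tolist())
```

The output order (sample 0, sample 0, sample 1, sample 1) matches what `input_ids.repeat_interleave(expand_size, dim=0)` produces for the text inputs, which keeps each expanded sequence aligned with its own images.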