qpqpqpqpqpqp committed on
Commit 36c87b2 · verified · 1 Parent(s): f30fdd2

Delete Junk

.ipynb_checkpoints/README-checkpoint.md DELETED
@@ -1,122 +0,0 @@
- ---
- pipeline_tag: image-text-to-text
- language:
- - multilingual
- tags:
- - deepseek
- - vision-language
- - ocr
- - custom_code
- license: mit
- ---
- <div align="center">
- <img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
- </div>
- <hr>
- <div align="center">
- <a href="https://www.deepseek.com/" target="_blank">
- <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
- </a>
- <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR" target="_blank">
- <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
- </a>
- </div>
-
- <div align="center">
- <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
- <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
- </a>
- <a href="https://twitter.com/deepseek_ai" target="_blank">
- <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
- </a>
- </div>
-
- <p align="center">
- <a href="https://github.com/deepseek-ai/DeepSeek-OCR"><b>🌟 Github</b></a> |
- <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR"><b>📥 Model Download</b></a> |
- <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
- <a href=""><b>📄 Arxiv Paper Link</b></a>
- </p>
- <h2>
- <p align="center">
- <a href="">DeepSeek-OCR: Contexts Optical Compression</a>
- </p>
- </h2>
- <p align="center">
- <img src="assets/fig1.png" style="width: 1000px" align=center>
- </p>
- <p align="center">
- <a href="">Explore the boundaries of visual-text compression.</a>
- </p>
-
- ## Usage
- Inference with Hugging Face Transformers on NVIDIA GPUs. Requirements tested with Python 3.12.9 + CUDA 11.8:
-
- ```
- torch==2.6.0
- transformers==4.46.3
- tokenizers==0.20.3
- einops
- addict
- easydict
- pip install flash-attn==2.7.3 --no-build-isolation
- ```
-
- ```python
- from transformers import AutoModel, AutoTokenizer
- import torch
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
- model_name = 'deepseek-ai/DeepSeek-OCR'
-
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
- model = model.eval().cuda().to(torch.bfloat16)
-
- # prompt = "<image>\nFree OCR. "
- prompt = "<image>\n<|grounding|>Convert the document to markdown. "
- image_file = 'your_image.jpg'
- output_path = 'your/output/dir'
-
- # infer(self, tokenizer, prompt='', image_file='', output_path=' ', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False)
-
- # Tiny:   base_size = 512,  image_size = 512,  crop_mode = False
- # Small:  base_size = 640,  image_size = 640,  crop_mode = False
- # Base:   base_size = 1024, image_size = 1024, crop_mode = False
- # Large:  base_size = 1280, image_size = 1280, crop_mode = False
-
- # Gundam: base_size = 1024, image_size = 640,  crop_mode = True
-
- res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path=output_path, base_size=1024, image_size=640, crop_mode=True, save_results=True, test_compress=True)
- ```
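A small convenience sketch (not part of the released API) that maps the mode names from the comment table above onto `infer()` keyword arguments; the sizes themselves are taken verbatim from that table:

```python
# Hypothetical helper: MODES mirrors the Tiny/Small/Base/Large/Gundam table
# above; model.infer() is the method shown in the usage snippet.
MODES = {
    "tiny":   dict(base_size=512,  image_size=512,  crop_mode=False),
    "small":  dict(base_size=640,  image_size=640,  crop_mode=False),
    "base":   dict(base_size=1024, image_size=1024, crop_mode=False),
    "large":  dict(base_size=1280, image_size=1280, crop_mode=False),
    "gundam": dict(base_size=1024, image_size=640,  crop_mode=True),
}

res = model.infer(tokenizer, prompt=prompt, image_file=image_file,
                  output_path=output_path, save_results=True, **MODES["gundam"])
```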
-
- ## vLLM
- Refer to [🌟 GitHub](https://github.com/deepseek-ai/DeepSeek-OCR/) for guidance on model inference acceleration, PDF processing, etc.
-
- ## Visualizations
- <table>
- <tr>
- <td><img src="assets/show1.jpg" style="width: 500px"></td>
- <td><img src="assets/show2.jpg" style="width: 500px"></td>
- </tr>
- <tr>
- <td><img src="assets/show3.jpg" style="width: 500px"></td>
- <td><img src="assets/show4.jpg" style="width: 500px"></td>
- </tr>
- </table>
-
- ## Acknowledgement
-
- We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), and [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.
-
- We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox) and [OmniDocBench](https://github.com/opendatalab/OmniDocBench).
-
- ## Citation
- Coming soon!
 
.ipynb_checkpoints/configuration_deepseek_v2-checkpoint.py DELETED
@@ -1,210 +0,0 @@
- from transformers.configuration_utils import PretrainedConfig
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
- DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
-
-
- class DeepseekV2Config(PretrainedConfig):
-     r"""
-     This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate a DeepSeek
-     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-     defaults will yield a configuration similar to that of DeepSeek-V2 with multi-latent attention.
-
-     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-     documentation from [`PretrainedConfig`] for more information.
-
-     Args:
-         vocab_size (`int`, *optional*, defaults to 102400):
-             Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
-             `inputs_ids` passed when calling [`DeepseekV2Model`].
-         hidden_size (`int`, *optional*, defaults to 4096):
-             Dimension of the hidden representations.
-         intermediate_size (`int`, *optional*, defaults to 11008):
-             Dimension of the MLP representations.
-         moe_intermediate_size (`int`, *optional*, defaults to 1407):
-             Dimension of the MoE representations.
-         num_hidden_layers (`int`, *optional*, defaults to 30):
-             Number of hidden layers in the Transformer decoder.
-         num_attention_heads (`int`, *optional*, defaults to 32):
-             Number of attention heads for each attention layer in the Transformer decoder.
-         n_shared_experts (`int`, *optional*, defaults to `None`):
-             Number of shared experts; `None` means a dense model.
-         n_routed_experts (`int`, *optional*, defaults to `None`):
-             Number of routed experts; `None` means a dense model.
-         routed_scaling_factor (`float`, *optional*, defaults to 1.0):
-             Scaling factor for routed experts.
-         topk_method (`str`, *optional*, defaults to `greedy`):
-             Top-k method used in the routed gate.
-         n_group (`int`, *optional*, defaults to `None`):
-             Number of groups for routed experts.
-         topk_group (`int`, *optional*, defaults to `None`):
-             Number of selected groups for each token (ensuring that the selected experts are only within `topk_group` groups).
-         num_experts_per_tok (`int`, *optional*, defaults to `None`):
-             Number of selected experts; `None` means a dense model.
-         moe_layer_freq (`int`, *optional*, defaults to 1):
-             The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
-         first_k_dense_replace (`int`, *optional*, defaults to 0):
-             Number of dense layers at the bottom of the model
-             (embed -> dense -> dense -> ... -> dense -> moe -> moe -> ... -> lm_head), i.e. the first k layers stay dense.
-         norm_topk_prob (`bool`, *optional*, defaults to `False`):
-             Whether to normalize the weights of the routed experts.
-         scoring_func (`str`, *optional*, defaults to `'softmax'`):
-             Method of computing expert weights.
-         aux_loss_alpha (`float`, *optional*, defaults to 0.001):
-             Auxiliary loss weight coefficient.
-         seq_aux (`bool`, *optional*, defaults to `True`):
-             Whether to compute the auxiliary loss for each individual sample.
-         num_key_value_heads (`int`, *optional*):
-             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
-             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
-             `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
-             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
-             by mean-pooling all the original heads within that group. For more details check out [this
-             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
-             `num_attention_heads`.
-         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
-             The non-linear activation function (function or string) in the decoder.
-         max_position_embeddings (`int`, *optional*, defaults to 2048):
-             The maximum sequence length that this model might ever be used with.
-         initializer_range (`float`, *optional*, defaults to 0.02):
-             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-             The epsilon used by the rms normalization layers.
-         use_cache (`bool`, *optional*, defaults to `True`):
-             Whether or not the model should return the last key/values attentions (not used by all models). Only
-             relevant if `config.is_decoder=True`.
-         pad_token_id (`int`, *optional*):
-             Padding token id.
-         bos_token_id (`int`, *optional*, defaults to 100000):
-             Beginning of stream token id.
-         eos_token_id (`int`, *optional*, defaults to 100001):
-             End of stream token id.
-         pretraining_tp (`int`, *optional*, defaults to 1):
-             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
-             document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
-             necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
-             issue](https://github.com/pytorch/pytorch/issues/76232).
-         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
-             Whether to tie weight embeddings.
-         rope_theta (`float`, *optional*, defaults to 10000.0):
-             The base period of the RoPE embeddings.
-         rope_scaling (`Dict`, *optional*):
-             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
-             strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
-             `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
-             `max_position_embeddings` to the expected new maximum.
-         attention_bias (`bool`, *optional*, defaults to `False`):
-             Whether to use a bias in the query, key, value and output projection layers during self-attention.
-         attention_dropout (`float`, *optional*, defaults to 0.0):
-             The dropout ratio for the attention probabilities.
-         use_mla (`bool`, *optional*, defaults to `True`):
-             Whether to use multi-latent attention. If `True`, the model uses multi-latent attention; otherwise it uses
-             multi-head attention.
-
-     ```python
-     >>> from transformers import DeepseekV2Model, DeepseekV2Config
-
-     >>> # Initializing a DeepSeek-V2 style configuration
-     >>> configuration = DeepseekV2Config()
-
-     >>> # Initializing a model from the configuration
-     >>> model = DeepseekV2Model(configuration)
-
-     >>> # Accessing the model configuration
-     >>> configuration = model.config
-     ```"""
-
-     model_type = "deepseek_v2"
-     keys_to_ignore_at_inference = ["past_key_values"]
-
-     def __init__(
-         self,
-         vocab_size=102400,
-         hidden_size=4096,
-         intermediate_size=11008,
-         moe_intermediate_size=1407,
-         num_hidden_layers=30,
-         num_attention_heads=32,
-         num_key_value_heads=32,
-         n_shared_experts=None,
-         n_routed_experts=None,
-         ep_size=1,
-         routed_scaling_factor=1.0,
-         kv_lora_rank=512,
-         q_lora_rank=1536,
-         qk_rope_head_dim=64,
-         v_head_dim=128,
-         qk_nope_head_dim=128,
-         topk_method='greedy',
-         n_group=None,
-         topk_group=None,
-         num_experts_per_tok=None,
-         moe_layer_freq=1,
-         first_k_dense_replace=0,
-         norm_topk_prob=False,
-         scoring_func='softmax',
-         aux_loss_alpha=0.001,
-         seq_aux=True,
-         hidden_act="silu",
-         max_position_embeddings=2048,
-         initializer_range=0.02,
-         rms_norm_eps=1e-6,
-         use_cache=True,
-         pad_token_id=None,
-         bos_token_id=100000,
-         eos_token_id=100001,
-         pretraining_tp=1,
-         tie_word_embeddings=False,
-         rope_theta=10000.0,
-         rope_scaling=None,
-         attention_bias=False,
-         attention_dropout=0.0,
-         use_mla=True,
-         **kwargs,
-     ):
-         self.vocab_size = vocab_size
-         self.max_position_embeddings = max_position_embeddings
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.moe_intermediate_size = moe_intermediate_size
-         self.num_hidden_layers = num_hidden_layers
-         self.num_attention_heads = num_attention_heads
-         self.n_shared_experts = n_shared_experts
-         self.n_routed_experts = n_routed_experts
-         self.ep_size = ep_size
-         self.routed_scaling_factor = routed_scaling_factor
-         self.kv_lora_rank = kv_lora_rank
-         self.q_lora_rank = q_lora_rank
-         self.qk_rope_head_dim = qk_rope_head_dim
-         self.v_head_dim = v_head_dim
-         self.qk_nope_head_dim = qk_nope_head_dim
-         self.topk_method = topk_method
-         self.n_group = n_group
-         self.topk_group = topk_group
-         self.num_experts_per_tok = num_experts_per_tok
-         self.moe_layer_freq = moe_layer_freq
-         self.first_k_dense_replace = first_k_dense_replace
-         self.norm_topk_prob = norm_topk_prob
-         self.scoring_func = scoring_func
-         self.aux_loss_alpha = aux_loss_alpha
-         self.seq_aux = seq_aux
-         # for backward compatibility
-         if num_key_value_heads is None:
-             num_key_value_heads = num_attention_heads
-
-         self.num_key_value_heads = num_key_value_heads
-         self.hidden_act = hidden_act
-         self.initializer_range = initializer_range
-         self.rms_norm_eps = float(rms_norm_eps)
-         self.pretraining_tp = pretraining_tp
-         self.use_cache = use_cache
-         self.rope_theta = rope_theta
-         self.rope_scaling = rope_scaling
-         self.attention_bias = attention_bias
-         self.attention_dropout = attention_dropout
-         self.use_mla = use_mla
-
-         super().__init__(
-             pad_token_id=pad_token_id,
-             bos_token_id=bos_token_id,
-             eos_token_id=eos_token_id,
-             tie_word_embeddings=tie_word_embeddings,
-             **kwargs,
-         )
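A minimal usage sketch for the config class above (assuming this module is importable as `configuration_deepseek_v2`; the overridden sizes are illustrative, not the released checkpoint's values):

```python
from configuration_deepseek_v2 import DeepseekV2Config

# Build a small config; anything not passed keeps the defaults above.
cfg = DeepseekV2Config(hidden_size=1024, num_hidden_layers=12, num_attention_heads=8)
print(cfg.model_type)               # "deepseek_v2"
print(cfg.use_mla)                  # True: multi-latent attention by default

# Passing num_key_value_heads=None falls back to num_attention_heads (MHA).
cfg_mha = DeepseekV2Config(num_attention_heads=8, num_key_value_heads=None)
print(cfg_mha.num_key_value_heads)  # 8
```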
 
.ipynb_checkpoints/conversation-checkpoint.py DELETED
@@ -1,280 +0,0 @@
- """
- From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- """
-
- import dataclasses
- from enum import IntEnum, auto
- from typing import Any, Dict, List
-
-
- class SeparatorStyle(IntEnum):
-     """Separator styles."""
-
-     DeepSeek = auto()
-     DeepSeekV2 = auto()
-     PLAIN = auto()
-     ALIGNMENT = auto()
-
-
- @dataclasses.dataclass
- class Conversation:
-     """A class that manages prompt templates and keeps all conversation history."""
-
-     # The name of this template
-     name: str
-     # The template of the system prompt
-     system_template: str = "{system_message}"
-     # The system message
-     system_message: str = ""
-     # The names of the two roles
-     roles: List[str] = ("USER", "ASSISTANT")
-     # All messages. Each item is (role, message).
-     messages: List[List[str]] = ()
-     # The number of few-shot examples
-     offset: int = 0
-     # The separator style and configurations
-     sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
-     sep: str = "\n"
-     sep2: str = None
-     # Stop criteria (the default one is the EOS token)
-     stop_str: str = None
-     # Stops generation when meeting any token in this list
-     stop_token_ids: List[int] = None
-
-     def get_prompt(self) -> str:
-         """Get the prompt for generation."""
-         system_prompt = self.system_template.format(system_message=self.system_message)
-         if self.sep_style == SeparatorStyle.DeepSeek:
-             seps = [self.sep, self.sep2]
-             if system_prompt == "" or system_prompt is None:
-                 ret = ""
-             else:
-                 ret = system_prompt + seps[0]
-             for i, (role, message) in enumerate(self.messages):
-                 if message:
-                     ret += role + ": " + message + seps[i % 2]
-                 else:
-                     ret += role + ":"
-             return ret
-         elif self.sep_style == SeparatorStyle.DeepSeekV2:
-             seps = [self.sep, self.sep2]
-             if system_prompt == "" or system_prompt is None:
-                 ret = ""
-             else:
-                 ret = system_prompt + seps[0]
-             for i, (role, message) in enumerate(self.messages):
-                 if message:
-                     if role == "User":
-                         # <|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
-                         ret += "<|sft▁begin|>\n" + message + self.sep
-                     else:
-                         ret += message + self.sep2
-             return ret
-
-         elif self.sep_style == SeparatorStyle.PLAIN:
-             seps = [self.sep, self.sep2]
-             ret = ""
-             for i, (role, message) in enumerate(self.messages):
-                 if message:
-                     if type(message) is tuple:
-                         message, _, _ = message
-                     ret += message + seps[i % 2]
-             return ret
-         elif self.sep_style == SeparatorStyle.ALIGNMENT:
-             seps = [self.sep, self.sep2]
-             ret = ""
-             for i, (role, message) in enumerate(self.messages):
-                 if message:
-                     if type(message) is tuple:
-                         message, _, _ = message
-                     if i % 2 == 0:
-                         ret += '<image>\n' + seps[i % 2]
-                     else:
-                         ret += message + seps[i % 2]
-             return ret
-         else:
-             raise ValueError(f"Invalid style: {self.sep_style}")
-
-     def set_system_message(self, system_message: str):
-         """Set the system message."""
-         self.system_message = system_message
-
-     def append_message(self, role: str, message: str):
-         """Append a new message."""
-         self.messages.append([role, message])
-
-     def update_last_message(self, message: str):
-         """Update the last output.
-
-         The last message is typically set to None when constructing the prompt,
-         so we need to update it in place after getting the response from a model.
-         """
-         self.messages[-1][1] = message
-
-     def reset_message(self):
-         """Reset all messages."""
-         self.messages = []
-
-     def to_gradio_chatbot(self):
-         """Convert the conversation to gradio chatbot format."""
-         ret = []
-         for i, (role, msg) in enumerate(self.messages[self.offset :]):
-             if i % 2 == 0:
-                 ret.append([msg, None])
-             else:
-                 ret[-1][-1] = msg
-         return ret
-
-     def to_openai_api_messages(self):
-         """Convert the conversation to OpenAI chat completion format."""
-         system_prompt = self.system_template.format(system_message=self.system_message)
-         ret = [{"role": "system", "content": system_prompt}]
-
-         for i, (_, msg) in enumerate(self.messages[self.offset :]):
-             if i % 2 == 0:
-                 ret.append({"role": "user", "content": msg})
-             else:
-                 if msg is not None:
-                     ret.append({"role": "assistant", "content": msg})
-         return ret
-
-     def copy(self):
-         return Conversation(
-             name=self.name,
-             system_template=self.system_template,
-             system_message=self.system_message,
-             roles=self.roles,
-             messages=[[x, y] for x, y in self.messages],
-             offset=self.offset,
-             sep_style=self.sep_style,
-             sep=self.sep,
-             sep2=self.sep2,
-             stop_str=self.stop_str,
-             stop_token_ids=self.stop_token_ids,
-         )
-
-     def dict(self):
-         return {
-             "template_name": self.name,
-             "system_message": self.system_message,
-             "roles": self.roles,
-             "messages": self.messages,
-             "offset": self.offset,
-         }
-
-
- # A global registry for all conversation templates
- conv_templates: Dict[str, Conversation] = {}
-
-
- def register_conv_template(template: Conversation, override: bool = False):
-     """Register a new conversation template."""
-     if not override:
-         assert template.name not in conv_templates, f"{template.name} has been registered."
-
-     conv_templates[template.name] = template
-
-
- def get_conv_template(name: str) -> Conversation:
-     """Get a copy of a registered conversation template."""
-     return conv_templates[name].copy()
-
-
- register_conv_template(
-     Conversation(
-         name="deepseek",
-         system_template="{system_message}",
-         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
-         # "thinking step by step to be sure you get the right answer.",
-         system_message="",
-         roles=("<|User|>", "<|Assistant|>"),
-         messages=(),
-         offset=0,
-         sep_style=SeparatorStyle.DeepSeek,
-         sep="\n\n",
-         sep2="<|end▁of▁sentence|>",
-         stop_token_ids=[100001],
-         stop_str=["User:", "<|end▁of▁sentence|>"],
-     )
- )
- register_conv_template(
-     Conversation(
-         name="deepseekv2",
-         system_template="{system_message}",
-         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
-         # "thinking step by step to be sure you get the right answer.",
-         system_message="",
-         roles=("<|User|>", "<|Assistant|>"),
-         messages=(),
-         offset=0,
-         sep_style=SeparatorStyle.DeepSeek,
-         sep="",
-         sep2="<|end▁of▁sentence|>",
-         stop_token_ids=[100001],
-         stop_str=["User:", "<|end▁of▁sentence|>"],
-     )
- )
-
-
- register_conv_template(
-     Conversation(
-         name="plain",
-         system_template="",
-         system_message="",
-         roles=("", ""),
-         messages=(),
-         offset=0,
-         sep_style=SeparatorStyle.PLAIN,
-         sep="",
-         sep2="",
-         stop_token_ids=[100001],
-         stop_str=['</s>'],
-     )
- )
-
-
- register_conv_template(
-     Conversation(
-         name="alignment",
-         system_template="",
-         system_message="",
-         roles=("", ""),
-         messages=(),
-         offset=0,
-         sep_style=SeparatorStyle.ALIGNMENT,
-         sep="",
-         sep2="",
-         stop_token_ids=[100001],
-         stop_str=['</s>'],
-     )
- )
-
-
- if __name__ == "__main__":
-     print("deepseek template:")
-     conv = get_conv_template("deepseek")
-     conv.append_message(conv.roles[0], "Hello!")
-     conv.append_message(conv.roles[1], "Hi! This is Tony.")
-     conv.append_message(conv.roles[0], "Who are you?")
-     conv.append_message(conv.roles[1], "I am a helpful assistant.")
-     conv.append_message(conv.roles[0], "How are you?")
-     conv.append_message(conv.roles[1], None)
-     print(conv.get_prompt())
-
-     print("deepseekv2 template:")
-     conv = get_conv_template("deepseekv2")
-     conv.append_message(conv.roles[0], "Hello!")
-     conv.append_message(conv.roles[1], "Hi! This is Tony.")
-     conv.append_message(conv.roles[0], "Who are you?")
-     conv.append_message(conv.roles[1], "I am a helpful assistant.")
-     conv.append_message(conv.roles[0], "How are you?")
-     conv.append_message(conv.roles[1], None)
-     print(conv.get_prompt())
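For reference, a sketch of what the `deepseek` template above renders to, derived by tracing `get_prompt()` with `sep="\n\n"` and `sep2="<|end▁of▁sentence|>"` (the final assistant turn is left open for generation):

```python
# Expected output of the first print(conv.get_prompt()) above:
#
# <|User|>: Hello!
#
# <|Assistant|>: Hi! This is Tony.<|end▁of▁sentence|><|User|>: Who are you?
#
# <|Assistant|>: I am a helpful assistant.<|end▁of▁sentence|><|User|>: How are you?
#
# <|Assistant|>:
```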
 
.ipynb_checkpoints/deepencoder-checkpoint.py DELETED
@@ -1,1058 +0,0 @@
- import copy
- import math
- from typing import Optional, Tuple, Type
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from easydict import EasyDict as adict
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
-
- class MlpProjector(nn.Module):
-
-     def __init__(self, cfg):
-         super().__init__()
-
-         self.cfg = cfg
-
-         if cfg.projector_type == "identity":
-             modules = nn.Identity()
-
-         elif cfg.projector_type == "linear":
-             modules = nn.Linear(cfg.input_dim, cfg.n_embed)
-
-         elif cfg.projector_type == "mlp_gelu":
-             mlp_depth = cfg.get("depth", 1)
-             modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
-             for _ in range(1, mlp_depth):
-                 modules.append(nn.GELU())
-                 modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
-             modules = nn.Sequential(*modules)
-
-         elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
-             mlp_depth = cfg.get("depth", 1)
-             mlp_ratio = cfg.get("mlp_ratio", 1)
-             modules = [
-                 nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
-                 nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
-             ]
-             for _ in range(1, mlp_depth - 1):
-                 modules.append(nn.GELU())
-                 modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
-             modules.append(nn.GELU())
-             modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
-             modules = nn.Sequential(*modules)
-
-         elif cfg.projector_type == "downsample_mlp_gelu":
-             mlp_depth = cfg.get("depth", 1)
-             mlp_ratio = cfg.get("mlp_ratio", 1)
-             modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
-             for _ in range(1, mlp_depth - 1):
-                 modules.append(nn.GELU())
-                 modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
-             modules.append(nn.GELU())
-             modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
-             modules = nn.Sequential(*modules)
-
-         elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
-             mlp_depth = cfg.get("depth", 1)
-             self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
-             self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
-
-             modules = []
-             for _ in range(1, mlp_depth):
-                 modules.append(nn.GELU())
-                 modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
-             modules = nn.Sequential(*modules)
-
-         elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
-             mlp_depth = cfg.get("depth", 1)
-             channel_div = cfg.get("channel_div", 0.5)
-             self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
-             self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
-
-             modules = []
-             for _ in range(1, mlp_depth):
-                 modules.append(nn.GELU())
-                 modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
-             modules = nn.Sequential(*modules)
-
-         elif cfg.projector_type == "low_high_split_mlp_gelu":
-             mlp_depth = cfg.get("depth", 1)
-             modules = []
-             for _ in range(1, mlp_depth):
-                 modules.append(nn.GELU())
-                 modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
-             modules = nn.Sequential(*modules)
-             self.high_layers = nn.Sequential(*modules)
-             self.low_layers = copy.deepcopy(modules)
-
-         else:
-             raise ValueError(f"Unknown projector type: {cfg.projector_type}")
-
-         if cfg.get("token_pooling", False):
-             self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
-
-         if cfg.get("conv_fusion_high_low_features", False):
-             self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
-         self.layers = modules
-
-     def forward(self, x):
-         if self.cfg.get("token_pooling", False):
-             batch_size, wxh, channels = x.shape
-             w = h = int(wxh**0.5)
-             x = x.view(batch_size, w, h, channels)
-             x = x.permute(0, 3, 1, 2)
-             patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
-             batch_size, channels, h_patches, w_patches, _, _ = patches.size()
-             # concatenate the 2x2 neighborhoods along the channel dimension
-             patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
-
-             # pass the pooled patches through the linear layer
-             patches = patches.permute(0, 2, 1, 3).contiguous()
-             patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
-
-             x = self.token_pooling_layer(patches)
-
-         if self.cfg.get("conv_fusion_high_low_features", False):
-             x = self.fusion_layer(x[:, 0]) + x[:, 1]
-
-         if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
-             high_x, low_x = x[0], x[1]
-             high_x = self.high_up_proj(high_x)
-             low_x = self.low_up_proj(low_x)
-             x = torch.concat([high_x, low_x], dim=-1)
-
-         if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
-             high_x = x[..., :self.cfg.input_dim[0]]
-             low_x = x[..., self.cfg.input_dim[0]:]
-             high_x = self.high_up_proj(high_x)
-             low_x = self.low_up_proj(low_x)
-             x = torch.concat([high_x, low_x], dim=-1)
-
-         if self.cfg.projector_type == 'low_high_split_mlp_gelu':
-             high_x, low_x = x[0], x[1]
-             high_x = self.high_layers(high_x)
-             low_x = self.low_layers(low_x)
-             x = torch.concat([high_x, low_x], dim=-1)
-             return x
-
-         if self.cfg.projector_type in ('downsample_mlp_gelu', 'normlayer_downsample_mlp_gelu'):
-             bs, hw, input_dim = x.shape
-             h = w = int(hw ** 0.5)
-
-             # compute padding
-             if h % self.cfg.downsample_ratio:
-                 pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
-             else:
-                 pad = 0
-             x = x.reshape(bs, h, w, input_dim)
-             if pad > 0:
-                 x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
-
-             # merge each downsample_ratio x downsample_ratio group of tokens into one
-             x = x.permute(0, 3, 1, 2)  # B, C, H, W
-             x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0)  # B, C*r*r, HW/(r*r)
-             x = x.permute(0, 2, 1)
-
-         return self.layers(x)
-
-     @staticmethod
-     def get_flops_per_sample(cfg):
-         if cfg.projector_type == "linear":
-             fwd = 2 * cfg.input_dim * cfg.n_embed
-
-         elif "mlp_gelu" in cfg.projector_type:
-             mlp_depth = cfg.get("depth", 1)
-             downsample_ratio = cfg.get("downsample_ratio", 1)
-             input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
-             input_dim = input_dim * downsample_ratio * downsample_ratio
-             fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
-         else:
-             fwd = 0
-
-         return fwd * 3
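A minimal sketch of the `downsample_mlp_gelu` path above (the dimensions are illustrative, not the released model's values): a grid of 256 tokens is merged 4-to-1 before the MLP, so 256 tokens of width 1024 become 64 tokens of width 1280.

```python
import torch
from easydict import EasyDict as adict

# Hypothetical config: downsample_ratio=2 merges each 2x2 token group.
cfg = adict(projector_type="downsample_mlp_gelu", input_dim=1024, n_embed=1280,
            depth=2, mlp_ratio=1, downsample_ratio=2)
proj = MlpProjector(cfg)

tokens = torch.randn(1, 256, 1024)   # B, HW, C with HW a perfect square (16x16)
out = proj(tokens)
print(out.shape)                     # torch.Size([1, 64, 1280])
```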
186
-
187
-
188
- #===================clip============================================================
189
-
190
- class LayerNormfp32(torch.nn.LayerNorm):
191
- """Subclass torch's LayerNorm to handle fp16."""
192
-
193
- def forward(self, x: torch.Tensor):
194
- orig_type = x.dtype
195
- ret = super().forward(x.type(torch.float32))
196
- return ret.type(orig_type)
197
-
198
-
199
- def get_abs_pos(abs_pos, tgt_size):
200
- # abs_pos: L, C
201
- # tgt_size: M
202
- # return: M, C
203
-
204
- # print(tgt_size)
205
- # print(abs_pos.shape)
206
- # exit()
207
- dim = abs_pos.size(-1)
208
- # print(dim)
209
- abs_pos_new = abs_pos.squeeze(0)
210
- cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
211
-
212
-
213
-
214
- src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
215
- tgt_size = int(math.sqrt(tgt_size))
216
- dtype = abs_pos.dtype
217
-
218
- if src_size != tgt_size:
219
- old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1,
220
- 2).contiguous()
221
- old_pos_embed = old_pos_embed.to(torch.float32)
222
- new_pos_embed = F.interpolate(
223
- old_pos_embed,
224
- size=(tgt_size, tgt_size),
225
- mode='bicubic',
226
- antialias=True,
227
- align_corners=False,
228
- ).to(dtype)
229
- new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
230
- new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
231
- vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
232
- vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
233
- return vision_pos_embed
234
- else:
235
- return abs_pos
236
-
237
- @torch.jit.script
238
- def quick_gelu(x):
239
- return x * torch.sigmoid(1.702 * x)
240
-
241
-
242
-
243
- class CLIPVisionEmbeddings(nn.Module):
244
- def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
245
- super().__init__()
246
- self.embed_dim = hidden_size
247
- self.image_size = image_size
248
- self.patch_size = patch_size
249
-
250
- self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
251
-
252
- self.patch_embedding = torch.nn.Conv2d(
253
- in_channels=num_channels,
254
- out_channels=self.embed_dim,
255
- kernel_size=self.patch_size,
256
- stride=self.patch_size,
257
- bias=False,
258
- )
259
-
260
- self.num_patches = (self.image_size // self.patch_size) ** 2
261
- self.num_positions = self.num_patches + 1
262
- self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
263
- self.register_buffer(
264
- "position_ids", torch.arange(self.num_positions).expand((1, -1))
265
- )
266
-
267
- def forward(self, pixel_values, patch_embeds):
268
- batch_size = pixel_values.shape[0]
269
- # patch_embeds = self.patch_embedding(
270
- # pixel_values
271
- # ) # shape = [*, width, grid, grid]
272
-
273
-
274
- if patch_embeds is not None:
275
- patch_embeds = patch_embeds
276
- # print(patch_embeds.shape)
277
- else:
278
- patch_embeds = self.patch_embedding(pixel_values)
279
- # print(111111)
280
- # shape = [*, width, grid, grid]
281
- # patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
282
-
283
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
284
-
285
-
286
- class_embeds = self.class_embedding.expand(batch_size, 1, -1)
287
- embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
288
-
289
- # x = torch.cat([cls_token, x], dim=1)
290
- embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
291
- # embeddings = embeddings + self.position_embedding(self.position_ids)
292
- return embeddings
293
-
294
-
295
- class NoTPFeedForward(nn.Module):
296
- def __init__(
297
- self,
298
- cfg,
299
- dim: int,
300
- hidden_dim: int,
301
- ):
302
- super().__init__()
303
-
304
- self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
305
- self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
306
-
307
- def forward(self, x):
308
- output = self.fc2(quick_gelu(self.fc1(x)))
309
- return output
310
-
311
-
312
-
313
-
314
- class NoTPAttention(torch.nn.Module):
315
- def __init__(self, cfg):
316
- super().__init__()
317
- self.num_heads = cfg.num_attention_heads
318
- self.n_local_heads = cfg.num_attention_heads
319
- self.head_dim = cfg.hidden_size // cfg.num_attention_heads
320
- self.max_seq_len = cfg.seq_length
321
- self.use_flash_attention = cfg.use_flash_attn
322
-
323
- self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
324
- self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
325
-
326
- # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
327
-
328
- self.attn_drop = cfg.attention_dropout
329
-
330
- def forward(
331
- self,
332
- x: torch.Tensor,
333
- ):
334
- bsz, seqlen, _ = x.shape
335
- xqkv = self.qkv_proj(x)
336
- xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
337
-
338
- if self.use_flash_attention:
339
-
340
- xq, xk, xv = torch.split(xqkv, 1, dim=2)
341
- xq = xq.squeeze(2)
342
- xk = xk.squeeze(2)
343
- xv = xv.squeeze(2)
344
- # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
345
-
346
- # (B, num_head, S, head_size)
347
- xq = xq.permute(0, 2, 1, 3)
348
- xk = xk.permute(0, 2, 1, 3)
349
- xv = xv.permute(0, 2, 1, 3)
350
- # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
351
- output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
352
- output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
353
- # output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
354
- else:
355
- # print(22222)
356
- xq, xk, xv = torch.split(xqkv, 1, dim=2)
357
- xq = xq.squeeze(2)
358
- xk = xk.squeeze(2)
359
- xv = xv.squeeze(2)
360
- # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
361
-
362
- # (B, num_head, S, head_size)
363
- xq = xq.permute(0, 2, 1, 3)
364
- xk = xk.permute(0, 2, 1, 3)
365
- xv = xv.permute(0, 2, 1, 3)
366
- # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
367
- output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
368
- output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
369
- # output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
370
- output = self.out_proj(output)
371
- return output
372
-
373
- class NoTPTransformerBlock(nn.Module):
374
- def __init__(self, cfg, layer_id: int, multiple_of=256):
375
- super().__init__()
376
-
377
- self.n_heads = cfg.num_attention_heads
378
- self.dim = cfg.hidden_size
379
- self.head_dim = cfg.hidden_size // cfg.num_attention_heads
380
- self.self_attn = NoTPAttention(cfg)
381
- self.mlp = NoTPFeedForward(
382
- cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
383
- )
384
- self.layer_id = layer_id
385
- self.layer_norm1 = torch.nn.LayerNorm(
386
- cfg.hidden_size, eps=cfg.layernorm_epsilon
387
- )
388
- self.layer_norm2 = torch.nn.LayerNorm(
389
- cfg.hidden_size, eps=cfg.layernorm_epsilon
390
- )
391
-
392
- def forward(self, x: torch.Tensor):
393
- residual = self.self_attn.forward(self.layer_norm1(x))
394
- h = x + residual
395
- out = h + self.mlp.forward(self.layer_norm2(h))
396
- return out
397
-
398
-
399
- class NoTPTransformer(nn.Module):
400
- def __init__(self, cfg):
401
- super().__init__()
402
-
403
- self.cfg = cfg
404
- # self.recompute_list = self.cfg.get("recompute_list", [])
405
- self.num_layers = cfg.num_layers # _get_num_layers(cfg)
406
-
407
- self.layers = torch.nn.ModuleList()
408
- for layer_id in range(self.num_layers):
409
- self.layers.append(
410
- NoTPTransformerBlock(
411
- cfg,
412
- layer_id + 1,
413
- )
414
- )
415
-
416
- def forward(
417
- self,
418
- hidden_states,
419
- ):
420
-
421
- for lid, layer in enumerate(self.layers):
422
- # if lid in self.recompute_list:
423
- # def custom(layer_id):
424
- # def custom_forward(*args, **kwargs):
425
- # x_ = self.layers[layer_id](*args, **kwargs)
426
- # return x_
427
-
428
- # return custom_forward
429
-
430
- # assert hidden_states.requires_grad == True, logger.warning(
431
- # "When using recalculation, the input must have grad fn"
432
- # )
433
- # hidden_states = tensor_parallel.checkpoint(
434
- # custom(lid),
435
- # False,
436
- # hidden_states.contiguous()
437
- # )
438
- # else:
439
- hidden_states = layer(hidden_states)
440
-
441
- return hidden_states
442
-
443
-
444
- # from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
445
-
446
- class VitModel(nn.Module):
447
- def __init__(
448
- self,
449
- cfg,
450
- freeze_embed=False,
451
- freeze_pre_norm=False
452
- ) -> None:
453
- super().__init__()
454
-
455
- self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)
456
-
457
- if freeze_embed:
458
- for name, param in self.embeddings.named_parameters():
459
- param.requires_grad = False
460
-
461
- self.transformer = NoTPTransformer(cfg=cfg)
462
-
463
- if cfg.get("fp32norm", False):
464
- logger.info("Load fp32 layernorm for ViT.")
465
- self.pre_layrnorm = LayerNormfp32(
466
- cfg.hidden_size,
467
- eps=cfg.get("pre_layernorm_epsilon", 1e-5),
468
- )
469
- else:
470
- self.pre_layrnorm = torch.nn.LayerNorm(
471
- cfg.hidden_size,
472
- eps=cfg.get("pre_layernorm_epsilon", 1e-5),
473
- )
474
-
475
- # self.pre_layrnorm = RMSNorm(
476
- # cfg.hidden_size,
477
- # eps=cfg.get("pre_layernorm_epsilon", 1e-5),
478
- # sequence_parallel=False,
479
- # use_fp32=True,
480
- # use_optimus=True,
481
- # )
482
-
483
- if freeze_pre_norm:
484
- for name, param in self.pre_layrnorm.named_parameters():
485
- param.requires_grad = False
486
-
487
- for p in self.parameters():
488
- p.micro_dp = True
489
-
490
- def set_input_tensor(self, input_tensor):
491
- if not isinstance(input_tensor, list):
492
- input_tensor = [input_tensor]
493
- self.transformer.set_input_tensor(input_tensor[0])
494
-
495
- def __str__(self) -> str:
496
- return "open_clip"
497
-
498
- def forward(
499
- self,
500
- x,
501
- patch_embeds
502
- ):
503
- x = self.embeddings(x, patch_embeds)
504
- hidden_states = self.pre_layrnorm(x)
505
-
506
- # hidden_states, dis = local_dp_scatter(hidden_states)
507
- output = self.transformer(hidden_states)
508
-
509
- # output = local_dp_reduce(output, dis)
510
-
511
- return output
512
-
513
-
514
- vit_model_cfg = adict(
515
- num_layers=24,
516
- hidden_size=1024,
517
- num_heads = 16,
518
- num_attention_heads=16,
519
- ffn_hidden_size=4096,
520
- seq_length=256,
521
- max_position_embeddings=256,
522
- use_flash_attn=False,
523
- understand_projector_stride=2,
524
- hidden_dropout = 0.0,
525
- attention_dropout = 0.0,
526
- no_persist_layer_norm = False,
527
- layernorm_epsilon = 1e-5,
528
- pre_layernorm_epsilon = 1e-5,
529
- image_size = 224,
530
- patch_size = 14,
531
- recompute_list = []
532
- )
533
-
534
- def build_clip_l():
535
- return VitModel(
536
- cfg=vit_model_cfg,
537
- freeze_embed=False,
538
- freeze_pre_norm=False,
539
- )
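A quick shape check for the CLIP-L tower above (a sketch: a 224x224 input with patch size 14 gives 16x16 = 256 patch tokens plus one CLS token):

```python
import torch

vit = build_clip_l()                                   # 24-layer, 1024-dim ViT-L
with torch.no_grad():
    out = vit(torch.randn(1, 3, 224, 224), patch_embeds=None)
print(out.shape)                                       # torch.Size([1, 257, 1024])
```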
-
-
- # ========================= SAM-Vary =========================
-
- def get_abs_pos_sam(abs_pos, tgt_size):
-     # abs_pos: (1, H, W, C) grid of positional embeddings; tgt_size: target grid side
-     dtype = abs_pos.dtype
-     src_size = abs_pos.size(1)
-
-     if src_size != tgt_size:
-         old_pos_embed = abs_pos.permute(0, 3, 1, 2)
-         old_pos_embed = old_pos_embed.to(torch.float32)
-         new_pos_embed = F.interpolate(
-             old_pos_embed,
-             size=(tgt_size, tgt_size),
-             mode='bicubic',
-             antialias=True,
-             align_corners=False,
-         ).to(dtype)
-         new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
-         return new_pos_embed
-     else:
-         return abs_pos
-
-
- class MLPBlock(nn.Module):
-     def __init__(
-         self,
-         embedding_dim: int,
-         mlp_dim: int,
-         act: Type[nn.Module] = nn.GELU,
-     ) -> None:
-         super().__init__()
-         self.lin1 = nn.Linear(embedding_dim, mlp_dim)
-         self.lin2 = nn.Linear(mlp_dim, embedding_dim)
-         self.act = act()
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return self.lin2(self.act(self.lin1(x)))
-
-
- # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py  # noqa
- # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
- class LayerNorm2d(nn.Module):
-     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(num_channels))
-         self.bias = nn.Parameter(torch.zeros(num_channels))
-         self.eps = eps
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         u = x.mean(1, keepdim=True)
-         s = (x - u).pow(2).mean(1, keepdim=True)
-         x = (x - u) / torch.sqrt(s + self.eps)
-         x = self.weight[:, None, None] * x + self.bias[:, None, None]
-         return x
-
-
- # This class and its supporting functions below are lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py  # noqa
- class ImageEncoderViT(nn.Module):
-     def __init__(
-         self,
-         img_size: int = 1024,
-         patch_size: int = 16,
-         in_chans: int = 3,
-         embed_dim: int = 768,
-         depth: int = 12,
-         num_heads: int = 12,
-         mlp_ratio: float = 4.0,
-         out_chans: int = 256,
-         qkv_bias: bool = True,
-         norm_layer: Type[nn.Module] = nn.LayerNorm,
-         act_layer: Type[nn.Module] = nn.GELU,
-         use_abs_pos: bool = True,
-         use_rel_pos: bool = False,
-         rel_pos_zero_init: bool = True,
-         window_size: int = 0,
-         global_attn_indexes: Tuple[int, ...] = (),
-     ) -> None:
-         """
-         Args:
-             img_size (int): Input image size.
-             patch_size (int): Patch size.
-             in_chans (int): Number of input image channels.
-             embed_dim (int): Patch embedding dimension.
-             depth (int): Depth of ViT.
-             num_heads (int): Number of attention heads in each ViT block.
-             mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-             qkv_bias (bool): If True, add a learnable bias to query, key, value.
-             norm_layer (nn.Module): Normalization layer.
-             act_layer (nn.Module): Activation layer.
-             use_abs_pos (bool): If True, use absolute positional embeddings.
-             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
-             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
-             window_size (int): Window size for window attention blocks.
-             global_attn_indexes (list): Indexes for blocks using global attention.
-         """
-         super().__init__()
-         self.img_size = img_size
-
-         self.patch_embed = PatchEmbed(
-             kernel_size=(patch_size, patch_size),
-             stride=(patch_size, patch_size),
-             in_chans=in_chans,
-             embed_dim=embed_dim,
-         )
-
-         self.pos_embed: Optional[nn.Parameter] = None
-         if use_abs_pos:
-             # Initialize absolute positional embedding with pretrain image size.
-             self.pos_embed = nn.Parameter(
-                 torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
-             )
-
-         self.blocks = nn.ModuleList()
-         for i in range(depth):
-             block = Block(
-                 dim=embed_dim,
-                 num_heads=num_heads,
-                 mlp_ratio=mlp_ratio,
-                 qkv_bias=qkv_bias,
-                 norm_layer=norm_layer,
-                 act_layer=act_layer,
-                 use_rel_pos=use_rel_pos,
-                 rel_pos_zero_init=rel_pos_zero_init,
-                 window_size=window_size if i not in global_attn_indexes else 0,
-                 input_size=(img_size // patch_size, img_size // patch_size),
-             )
-             self.blocks.append(block)
-
-         self.neck = nn.Sequential(
-             nn.Conv2d(
-                 embed_dim,
-                 out_chans,
-                 kernel_size=1,
-                 bias=False,
-             ),
-             LayerNorm2d(out_chans),
-             nn.Conv2d(
-                 out_chans,
-                 out_chans,
-                 kernel_size=3,
-                 padding=1,
-                 bias=False,
-             ),
-             LayerNorm2d(out_chans),
-         )
-
-         # Two extra stride-2 convs downsample the neck output a further 4x.
-         self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
-         self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.patch_embed(x)
-         if self.pos_embed is not None:
-             x = x + get_abs_pos_sam(self.pos_embed, x.size(1))
-
-         for blk in self.blocks:
-             x = blk(x)
-
-         x = self.neck(x.permute(0, 3, 1, 2))
-         x2 = self.net_2(x)
-         x3 = self.net_3(x2.clone())
-
-         return x3
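A shape sketch for the encoder above (illustrative SAM-ViT-B sizes, not necessarily the released configuration): the patch embedding turns a 1024x1024 image into a 64x64 grid, and `net_2`/`net_3` halve the neck output twice more, so the returned feature map is 16x16 with 1024 channels.

```python
import torch

enc = ImageEncoderViT(img_size=1024, patch_size=16, embed_dim=768, depth=12,
                      num_heads=12, window_size=14,
                      global_attn_indexes=(2, 5, 8, 11))
with torch.no_grad():
    feats = enc(torch.randn(1, 3, 1024, 1024))
print(feats.shape)  # torch.Size([1, 1024, 16, 16])
```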
-
-
- class Block(nn.Module):
-     """Transformer block with support for window attention and residual propagation."""
-
-     def __init__(
-         self,
-         dim: int,
-         num_heads: int,
-         mlp_ratio: float = 4.0,
-         qkv_bias: bool = True,
-         norm_layer: Type[nn.Module] = nn.LayerNorm,
-         act_layer: Type[nn.Module] = nn.GELU,
-         use_rel_pos: bool = False,
-         rel_pos_zero_init: bool = True,
-         window_size: int = 0,
-         input_size: Optional[Tuple[int, int]] = None,
-     ) -> None:
-         """
-         Args:
-             dim (int): Number of input channels.
-             num_heads (int): Number of attention heads in each ViT block.
-             mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-             qkv_bias (bool): If True, add a learnable bias to query, key, value.
-             norm_layer (nn.Module): Normalization layer.
-             act_layer (nn.Module): Activation layer.
-             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
-             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
-             window_size (int): Window size for window attention blocks. If it equals 0, then
-                 use global attention.
-             input_size (tuple(int, int) or None): Input resolution for calculating the relative
-                 positional parameter size.
-         """
-         super().__init__()
-         self.norm1 = norm_layer(dim)
-         self.attn = Attention(
-             dim,
-             num_heads=num_heads,
-             qkv_bias=qkv_bias,
-             use_rel_pos=use_rel_pos,
-             rel_pos_zero_init=rel_pos_zero_init,
-             input_size=input_size if window_size == 0 else (window_size, window_size),
-         )
-
-         self.norm2 = norm_layer(dim)
-         self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
-
-         self.window_size = window_size
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         shortcut = x
-         x = self.norm1(x)
-         # Window partition
-         if self.window_size > 0:
-             H, W = x.shape[1], x.shape[2]
-             x, pad_hw = window_partition(x, self.window_size)
-
-         x = self.attn(x)
-         # Reverse window partition
-         if self.window_size > 0:
-             x = window_unpartition(x, self.window_size, pad_hw, (H, W))
-
-         x = shortcut + x
-         x = x + self.mlp(self.norm2(x))
-
-         return x
-
-
- class Attention(nn.Module):
-     """Multi-head Attention block with relative position embeddings."""
-
-     def __init__(
-         self,
-         dim: int,
-         num_heads: int = 8,
-         qkv_bias: bool = True,
-         use_rel_pos: bool = False,
-         rel_pos_zero_init: bool = True,
-         input_size: Optional[Tuple[int, int]] = None,
-     ) -> None:
-         """
-         Args:
-             dim (int): Number of input channels.
-             num_heads (int): Number of attention heads.
-             qkv_bias (bool): If True, add a learnable bias to query, key, value.
-             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
-             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
-             input_size (tuple(int, int) or None): Input resolution for calculating the relative
-                 positional parameter size.
-         """
-         super().__init__()
-         self.num_heads = num_heads
-         head_dim = dim // num_heads
-         self.scale = head_dim**-0.5
-
-         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-         self.proj = nn.Linear(dim, dim)
-
-         self.use_rel_pos = use_rel_pos
-         if self.use_rel_pos:
-             assert (
-                 input_size is not None
-             ), "Input size must be provided if using relative positional encoding."
-             # initialize relative positional embeddings
-             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
-             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         B, H, W, _ = x.shape
-         # qkv with shape (3, B, nHead, H * W, C)
-         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-         # q, k, v with shape (B * nHead, H * W, C)
-         q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
-
-         rel_h, rel_w = None, None
-         if self.use_rel_pos:
-             rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
-
-         q = q.view(B, self.num_heads, H * W, -1)
-         k = k.view(B, self.num_heads, H * W, -1)
-         v = v.view(B, self.num_heads, H * W, -1)
-
-         if self.use_rel_pos:
-             rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
-             rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
-             attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
-             x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
-         else:
-             x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-
-         x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
-
-         x = self.proj(x)
-
-         return x
-
-
- def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
-     """
-     Partition into non-overlapping windows with padding if needed.
-     Args:
-         x (tensor): input tokens with [B, H, W, C].
-         window_size (int): window size.
-
-     Returns:
-         windows: windows after partition with [B * num_windows, window_size, window_size, C].
-         (Hp, Wp): padded height and width before partition
-     """
-     B, H, W, C = x.shape
-
-     pad_h = (window_size - H % window_size) % window_size
-     pad_w = (window_size - W % window_size) % window_size
-     if pad_h > 0 or pad_w > 0:
-         x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
-     Hp, Wp = H + pad_h, W + pad_w
-
-     x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
-     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-     return windows, (Hp, Wp)
-
-
- def window_unpartition(
-     windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
- ) -> torch.Tensor:
-     """
-     Window unpartition into original sequences, removing padding.
-     Args:
-         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
-         window_size (int): window size.
-         pad_hw (Tuple): padded height and width (Hp, Wp).
-         hw (Tuple): original height and width (H, W) before padding.
-
-     Returns:
-         x: unpartitioned sequences with [B, H, W, C].
-     """
-     Hp, Wp = pad_hw
-     H, W = hw
-     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-     x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
-     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
-
-     if Hp > H or Wp > W:
-         x = x[:, :H, :W, :].contiguous()
-     return x
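A quick sanity check of the partition/unpartition pair above: since the two functions are inverse reshapes (with padding added and then stripped), the roundtrip recovers the input exactly.

```python
import torch

x = torch.randn(2, 14, 14, 32)                     # B, H, W, C
windows, pad_hw = window_partition(x, window_size=7)
print(windows.shape)                               # torch.Size([8, 7, 7, 32]): B * 4 windows
x_back = window_unpartition(windows, 7, pad_hw, (14, 14))
assert torch.equal(x_back, x)
```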
897


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos to the required length.
        dtype = rel_pos.dtype
        rel_pos = rel_pos.to(torch.float32)
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        ).to(dtype)
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950
    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        rel_h, rel_w (Tensor): decomposed height and width relative-position terms,
        shaped for broadcast summation into an additive attention bias.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    rel_h = rel_h.unsqueeze(-1)
    rel_w = rel_w.unsqueeze(-2)
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)

    return rel_h, rel_w
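The two returned factors broadcast against each other before being folded into the attention bias used at the top of this file. A shape sketch under invented dimensions (batch and heads pre-folded into the leading axis, as in the attention forward):

```python
import torch

B_heads, C = 2 * 12, 64                           # batch * num_heads, per-head channels
q_h = q_w = k_h = k_w = 14                        # one 14x14 window
q = torch.randn(B_heads, q_h * q_w, C)
rel_pos = torch.randn(2 * 14 - 1, C)              # (2*size - 1) learned offsets per axis
rel_h, rel_w = add_decomposed_rel_pos(q, rel_pos, rel_pos, (q_h, q_w), (k_h, k_w))
bias = rel_h + rel_w                              # broadcasts to (B_heads, 196, 14, 14)
bias = bias.view(B_heads, q_h * q_w, k_h * k_w)   # (B_heads, 196, 196) additive attn bias
```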


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
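As a quick sanity check on the patch tokenization (input size chosen to match the 1024-pixel SAM resolution used by the builder below):

```python
import torch

embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), embed_dim=768)
img = torch.randn(1, 3, 1024, 1024)
tokens = embed(img)
print(tokens.shape)    # torch.Size([1, 64, 64, 768]): 1024 / 16 = 64 patches per side
```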


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
    image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
    image_encoder = torch.compile(image_encoder, mode=compile_mode)
    return image_encoder


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )
    image_encoder.eval()
    if checkpoint is not None:
        state_dict = torch.load(checkpoint)
        # keep only the high-res vision tower weights and strip their key prefix
        image_encoder.load_state_dict(
            {k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True
        )
        print(checkpoint)
    return image_encoder
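A minimal construction sketch, with no checkpoint passed; in this repo the encoder's weights normally arrive with the full model's state dict rather than through the `checkpoint` argument, so the output below is from a randomly initialized encoder and its exact shape depends on the encoder's neck:

```python
import torch

encoder = build_sam_vit_b()
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 1024, 1024))
```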
.ipynb_checkpoints/modeling_deepseekocr-checkpoint.py DELETED
@@ -1,1043 +0,0 @@
# MIT License modified from prithivMLmods/DeepSeek-OCR-Latest-BF16.I64
import os
import math
import re
from tqdm import tqdm
from abc import ABC
from typing import List, Optional, Tuple, Union

from addict import Dict
from PIL import Image, ImageOps, ImageDraw, ImageFont
import numpy as np

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torchvision import transforms

from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers import DeepseekV2Model, DeepseekV2ForCausalLM, DeepseekV2Config
from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
    DeepseekV2Attention, DeepseekV2MLP, DeepseekV2MoE, DeepseekV2RMSNorm, DeepseekV2DecoderLayer)
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
from transformers import TextStreamer
from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
from .conversation import get_conv_template

torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def load_image(image_path):
    try:
        image = Image.open(image_path)
        # respect the EXIF orientation flag so rotated photos OCR correctly
        corrected_image = ImageOps.exif_transpose(image)
        return corrected_image
    except Exception as e:
        print(f"error: {e}")
        try:
            return Image.open(image_path)
        except Exception:
            return None


def re_match(text):
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)

    matches_image = []
    matches_other = []
    for a_match in matches:
        if '<|ref|>image<|/ref|>' in a_match[0]:
            matches_image.append(a_match[0])
        else:
            matches_other.append(a_match[0])
    return matches, matches_image, matches_other
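For reference, the grounded output format this parses looks like the following (the sample string is invented for illustration):

```python
sample = "<|ref|>title<|/ref|><|det|>[[120, 40, 880, 90]]<|/det|>Hello"
matches, matches_image, matches_other = re_match(sample)
print(matches[0][1], matches[0][2])    # -> title [[120, 40, 880, 90]]
```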


def extract_coordinates_and_label(ref_text, image_width, image_height):
    try:
        label_type = ref_text[1]
        cor_list = eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)


def draw_bounding_boxes(image, refs, output_path):
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    img_idx = 0

    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result

                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
                color_a = color + (20, )
                for points in points_list:
                    x1, y1, x2, y2 = points

                    # coordinates are emitted on a 0-999 grid; rescale to pixels
                    x1 = int(x1 / 999 * image_width)
                    y1 = int(y1 / 999 * image_height)
                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)

                    if label_type == 'image':
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{output_path}/images/{img_idx}.jpg")
                        except Exception as e:
                            print(e)
                        img_idx += 1

                    try:
                        if label_type == 'title':
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        text_x = x1
                        text_y = max(0, y1 - 15)

                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                       fill=(255, 255, 255, 30))
                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except Exception:
                        pass
        except Exception:
            continue
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw


def process_image_with_refs(image, ref_texts, output_path):
    result_image = draw_bounding_boxes(image, ref_texts, output_path)
    return result_image


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # on ties, prefer the grid with more tiles if the image is large enough to fill it
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate candidate tiling grids (i columns x j rows) with min_num <= i*j <= max_num
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image, then split it into image_size x image_size tiles
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio
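To make the tiling concrete, here is a small sketch of what `dynamic_preprocess` does for a 16:9 page (sizes invented for illustration):

```python
from PIL import Image

page = Image.new('RGB', (1920, 1080))
tiles, ratio = dynamic_preprocess(page, image_size=640)
print(ratio, len(tiles))    # (4, 2): a 4x2 grid, i.e. 8 tiles of 640x640
```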


def normalize_transform(mean, std):
    if mean is None and std is None:
        transform = None
    elif mean is None and std is not None:
        mean = [0.] * len(std)
        transform = transforms.Normalize(mean=mean, std=std)
    elif mean is not None and std is None:
        std = [1.] * len(mean)
        transform = transforms.Normalize(mean=mean, std=std)
    else:
        transform = transforms.Normalize(mean=mean, std=std)

    return transform


def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """
    Applies the SFT template to the conversation.

    Args:
        conversations (List[Dict]): A List of messages.
        sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
        system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

    Returns:
        sft_prompt (str): The formatted text.
    """
    conv = get_conv_template(sft_format)
    conv.set_system_message(system_prompt)
    for message in conversations:
        conv.append_message(message["role"], message["content"].strip())
    sft_prompt = conv.get_prompt().strip()

    return sft_prompt


def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    t = tokenizer.encode(text, add_special_tokens=False)
    bos_id = 0
    eos_id = 1
    if bos:
        t = [bos_id] + t
    if eos:
        t = t + [eos_id]

    return t


def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """
    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is:
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images.
    """
    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_path in message["images"]:
            pil_img = load_image(image_path)
            pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

    return pil_images


class BaseTransform(ABC):

    def set_rng(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        pass

    @property
    def default_shape(self):
        raise NotImplementedError


class BasicImageTransform(BaseTransform):
    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        transform_pipelines = [
            transforms.ToTensor()
        ]

        normalize = normalize_transform(mean, std) if normalize else nn.Identity()
        if normalize is not None:
            transform_pipelines.append(normalize)

        self.transform = transforms.Compose(transform_pipelines)

    def __call__(self, x):
        x = self.transform(x)
        return x


class NoEOSTextStreamer(TextStreamer):
    def on_finalized_text(self, text: str, stream_end: bool = False):
        # print streamed text, replacing the EOS marker with a newline
        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        text = text.replace(eos_text, "\n")
        print(text, flush=True, end="")


def decoder_layer_init(self, config: DeepseekV2Config, layer_idx: int):
    nn.Module.__init__(self)
    self.hidden_size = config.hidden_size

    if config.use_mla:
        self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
    else:
        config.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = LlamaAttention(config, layer_idx)
    self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)

    self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


# Patch the upstream decoder layer so MLA can be swapped for standard attention per config.
DeepseekV2DecoderLayer.__init__ = decoder_layer_init
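The assignment above monkey-patches the upstream class rather than subclassing it, so every `DeepseekV2DecoderLayer` constructed afterwards, including those built inside `DeepseekV2Model`, picks up the `use_mla` switch. A hedged sketch of the effect, with the config fields assumed to take their repo defaults:

```python
cfg = DeepseekV2Config(hidden_size=1024, num_attention_heads=8)   # field values invented
layer = DeepseekV2DecoderLayer(cfg, layer_idx=0)
print(type(layer.self_attn).__name__)   # DeepseekV2Attention when cfg.use_mla is True
```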

class DeepseekOCRConfig(DeepseekV2Config):
    model_type = "DeepseekOCR"


class DeepseekOCRModel(DeepseekV2Model):
    config_class = DeepseekOCRConfig

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCRModel, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()
        n_embed = 1280
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)  # (sic) name kept for checkpoint compatibility

        self.rotary_emb = LlamaRotaryEmbedding(config=config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        inputs_embeds = inputs_embeds.clone()

        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)

        if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:

            idx = 0

            for image, crop_shape in zip(images, images_spatial_crop):
                images_in_this_batch = []

                patches = image[0]
                image_ori = image[1]

                with torch.no_grad():

                    if torch.sum(patches).item() != 0:
                        # local (tiled) views: SAM features and CLIP features, concatenated per token
                        local_features_1 = sam_model(patches)
                        local_features_2 = vision_model(patches, local_features_1)
                        local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        local_features = self.projector(local_features)

                        # global view of the whole page
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        global_features = self.projector(global_features)

                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        _2, hw2, n_dim2 = local_features.shape
                        h2 = w2 = int(hw2 ** 0.5)

                        width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

                        global_features = global_features.view(h, w, n_dim)
                        global_features = torch.cat(
                            [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                        )
                        global_features = global_features.view(-1, n_dim)

                        local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
                        local_features = torch.cat(
                            [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
                        )
                        local_features = local_features.view(-1, n_dim2)

                        global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)

                    else:
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        global_features = self.projector(global_features)
                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        global_features = global_features.view(h, w, n_dim)
                        global_features = torch.cat(
                            [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                        )
                        global_features = global_features.view(-1, n_dim)

                        global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)

                    images_in_this_batch.append(global_local_features)

                if images_in_this_batch:
                    images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                    images_in_this_batch = images_in_this_batch.to(
                        device=inputs_embeds.device, dtype=inputs_embeds.dtype
                    )
                    mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)  # bool [T, 1]
                    updated_row = inputs_embeds[idx].masked_scatter(mask, images_in_this_batch)
                    inputs_embeds[idx] = updated_row

                idx += 1

        return super(DeepseekOCRModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
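The key trick above is `masked_scatter`: the tokenized prompt already contains one placeholder token per visual token, and the projected vision features are written over exactly those positions. A toy sketch of the mechanism (dimensions invented):

```python
import torch

T, D = 6, 4
embeds = torch.zeros(T, D)
mask = torch.tensor([False, True, True, True, False, False]).unsqueeze(-1)
vision = torch.ones(3, D)                      # one row per placeholder token
embeds = embeds.masked_scatter(mask, vision)   # rows 1..3 now hold vision features
```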


class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

    config_class = DeepseekOCRConfig

    def __init__(self, config):
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCRModel(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Omit tokens covered by past_key_values
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.get_seq_length()
                max_cache_length = None
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
        # same goes for position ids. Could also help with continued generation.
        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.
        """
        import torch
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
    def infer(self, tokenizer, prompt='', image_file='', output_path='', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
        self.disable_torch_init()

        os.makedirs(output_path, exist_ok=True)
        os.makedirs(f'{output_path}/images', exist_ok=True)

        if prompt and image_file:
            conversation = [
                {
                    "role": "<|User|>",
                    "content": f'{prompt}',
                    # other useful prompts:
                    # "content": "<image>\n<|grounding|>Given the layout of the image. ",
                    # "content": "<image>\nFree OCR. ",
                    # "content": "<image>\nParse the figure. ",
                    # "content": "<image>\nExtract the text in the image. ",
                    "images": [f'{image_file}'],
                },
                {"role": "<|Assistant|>", "content": ""},
            ]
        elif prompt:
            conversation = [
                {
                    "role": "<|User|>",
                    "content": f'{prompt}',
                },
                {"role": "<|Assistant|>", "content": ""},
            ]
        else:
            raise ValueError('prompt is none!')

        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0
        ratio = 1

        image_draw = images[0].copy()

        w, h = image_draw.size
        # min(w, h) / max(w, h): fraction of the padded square covered by real content
        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)

        image_token = '<image>'
        image_token_id = 128815
        text_splits = prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):

            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:

                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]
                else:
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                """process the global view"""
                global_view = ImageOps.pad(image, (base_size, base_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch_dtype))

                width_crop_num, height_crop_num = crop_ratio

                images_spatial_crop.append([width_crop_num, height_crop_num])

                if width_crop_num > 1 or height_crop_num > 1:
                    """process the local views"""
                    for i in range(len(images_crop_raw)):
                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch_dtype))

                    if image_size == 640:
                        valid_img_tokens += len(images_crop_list) * 100

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                """add image tokens"""
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

            else:
                """process the global view"""
                if image_size <= 640:
                    print('directly resize')
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))
                images_list.append(image_transform(global_view).to(torch_dtype))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += 100
                elif base_size == 512:
                    valid_img_tokens += 64

                width_crop_num, height_crop_num = 1, 1

                images_spatial_crop.append([width_crop_num, height_crop_num])

                """add image tokens"""
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)

                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        """process the last text split"""
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        """add the bos token"""
        bos_id = 0
        tokenized_str = [bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        if not eval_mode:
            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
            with torch.autocast("cuda", dtype=torch_dtype):
                with torch.no_grad():
                    output_ids = self.generate(
                        input_ids.unsqueeze(0).cuda(),
                        images=[(images_crop.cuda(), images_ori.cuda())],
                        images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                        images_spatial_crop=images_spatial_crop,
                        temperature=0.0,
                        eos_token_id=tokenizer.eos_token_id,
                        streamer=streamer,
                        max_new_tokens=8192,
                        no_repeat_ngram_size=20,
                        use_cache=True
                    )
        else:
            with torch.autocast("cuda", dtype=torch_dtype):
                with torch.no_grad():
                    output_ids = self.generate(
                        input_ids.unsqueeze(0).cuda(),
                        images=[(images_crop.cuda(), images_ori.cuda())],
                        images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                        images_spatial_crop=images_spatial_crop,
                        temperature=0.0,
                        eos_token_id=tokenizer.eos_token_id,
                        max_new_tokens=8192,
                        no_repeat_ngram_size=35,
                        use_cache=True
                    )

        if '<image>' in conversation[0]['content'] and eval_mode:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
            stop_str = '<|end▁of▁sentence|>'
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            return outputs

        if '<image>' in conversation[0]['content'] and test_compress:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('=' * 50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', round(pure_texts_outputs_token_length / valid_img_tokens, 2))
            print('=' * 50)

        if '<image>' in conversation[0]['content'] and save_results:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
            stop_str = '<|end▁of▁sentence|>'

            print('=' * 15 + 'save results:' + '=' * 15)

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            matches_ref, matches_images, matches_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)

            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            for idx, a_match_other in enumerate(tqdm(matches_other, desc="other")):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                import matplotlib.pyplot as plt
                lines = eval(outputs)['Line']['line']
                line_type = eval(outputs)['Line']['line_type']
                endpoints = eval(outputs)['Line']['line_endpoint']

                fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)

                for idx, line in enumerate(lines):
                    try:
                        p0 = eval(line.split(' -- ')[0])
                        p1 = eval(line.split(' -- ')[-1])

                        if line_type[idx] == '--':
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k', linestyle='--')
                        else:
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')

                        ax.scatter(p0[0], p0[1], s=5, color='k')
                        ax.scatter(p1[0], p1[1], s=5, color='k')
                    except Exception:
                        pass

                for endpoint in endpoints:
                    label = endpoint.split(': ')[0]
                    (x, y) = eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')

                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()

            result.save(f"{output_path}/result_with_boxes.jpg")
.ipynb_checkpoints/modeling_deepseekv2-checkpoint.py DELETED
@@ -1,1996 +0,0 @@
# coding=utf-8
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DeepSeek model, compatible with both DeepSeekV2 and DeepSeekV3."""
import math
import warnings
from typing import List, Optional, Tuple, Union
import numpy as np

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torch.distributed as dist
from einops import repeat
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
try:
    from transformers.models.llama.modeling_llama import LlamaAttention
except ImportError:
    LlamaAttention = None
try:
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
except ImportError:
    LlamaFlashAttention2 = None
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import (
    ALL_LAYERNORM_LAYERS,
    is_torch_greater_or_equal_than_1_13,
)
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available

from .configuration_deepseek_v2 import DeepseekV2Config

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DeepseekV2Config"


def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


class DeepseekV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
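A quick numerical check that mirrors the forward above: RMSNorm is `y = w * x / sqrt(mean(x**2) + eps)`, i.e. LayerNorm without mean-centering and without a bias (values invented):

```python
import torch

x = torch.randn(2, 8)
norm = DeepseekV2RMSNorm(8, eps=1e-6)
y = norm(x)
ref = norm.weight * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
assert torch.allclose(y, ref, atol=1e-5)
```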


ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)


class DeepseekV2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2
class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
    """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        t = t / self.scaling_factor

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2
class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
    """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
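A quick sketch of what these caches hold, with the head dim assumed to be 64:

```python
import torch

rope = DeepseekV2RotaryEmbedding(dim=64, max_position_embeddings=2048)
cos, sin = rope(torch.empty(1, 1, 128, 64), seq_len=128)
print(cos.shape, sin.shape)    # torch.Size([128, 64]) each: one (cos, sin) row per position
```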
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min_val, max_val, dim):
    if min_val == max_val:
        max_val += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
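For intuition, with values chosen only as an example: the correction range tells YaRN which frequency bands to interpolate and which to leave alone, and the ramp blends between the two regimes per frequency.

```python
low, high = yarn_find_correction_range(32, 1, 64, base=10000, max_position_embeddings=4096)
ramp = 1.0 - yarn_linear_ramp_mask(low, high, 32)   # per-frequency blend weights in [0, 1]
print(low, high, ramp.shape)                        # band edges and torch.Size([32])
```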
-
267
-
268
class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        freq_extra = 1.0 / (
            self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offset position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # Regroup the interleaved pair layout into the half-split layout that rotate_half expects
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

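
# --- Editor's illustrative sketch; `_sketch_pair_regroup` is not part of the original file ---
# Unlike the stock Llama helper, `apply_rotary_pos_emb` above first regroups
# each head dim from the interleaved pair layout [x0, y0, x1, y1, ...] into
# the half-split layout [x0, x1, ..., y0, y1, ...] that `rotate_half` and the
# concatenated cos/sin caches expect:
def _sketch_pair_regroup():
    x = torch.arange(8.0).view(1, 1, 1, 8)  # [0, 1, 2, 3, 4, 5, 6, 7]
    b, h, s, d = x.shape
    regrouped = x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    return regrouped.view(-1)  # tensor([0., 2., 4., 6., 1., 3., 5., 7.])
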
class DeepseekV2MLP(nn.Module):
    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj

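
# --- Editor's illustrative sketch; `_sketch_gated_mlp` is not part of the original file ---
# DeepseekV2MLP is the usual gated (SwiGLU-style) feed-forward:
# down_proj(act(gate_proj(x)) * up_proj(x)). Written out with plain tensors
# and SiLU on made-up sizes:
def _sketch_gated_mlp(hidden=16, inter=32, batch=2):
    gate = torch.randn(inter, hidden)
    up = torch.randn(inter, hidden)
    down = torch.randn(hidden, inter)
    x = torch.randn(batch, hidden)
    return F.linear(F.silu(F.linear(x, gate)) * F.linear(x, up), down)  # [batch, hidden]
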
class MoEGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        elif self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"Unsupported scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        elif self.topk_method == "group_limited_greedy":
            group_scores = (
                scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
            topk_weight, topk_idx = torch.topk(
                tmp_scores, k=self.top_k, dim=-1, sorted=False
            )
        elif self.topk_method == "noaux_tc":
            assert not self.training
            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
            _, topk_idx = torch.topk(
                tmp_scores, k=self.top_k, dim=-1, sorted=False
            )
            topk_weight = scores.gather(1, topk_idx)

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator * self.routed_scaling_factor
        else:
            topk_weight = topk_weight * self.routed_scaling_factor
        ### expert-level computation auxiliary loss
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            # always compute aux loss based on the naive greedy topk method
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(
                    bsz, self.n_routed_experts, device=hidden_states.device
                )
                ce.scatter_add_(
                    1,
                    topk_idx_for_aux_loss,
                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
                ).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
                    dim=1
                ).mean() * self.alpha
            else:
                mask_ce = F.one_hot(
                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
                )
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss

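
# --- Editor's illustrative sketch; `_sketch_group_limited_topk` is not part of the original file ---
# `group_limited_greedy` first keeps only the `topk_group` expert groups with
# the highest per-group max score, then runs plain top-k inside the surviving
# groups. The same masking trick on a made-up 1-token, 4-expert, 2-group case:
def _sketch_group_limited_topk():
    scores = torch.tensor([[0.1, 0.4, 0.3, 0.2]])            # [tokens, experts]
    group_scores = scores.view(1, 2, -1).max(dim=-1).values  # per-group max
    group_idx = torch.topk(group_scores, k=1, dim=-1)[1]     # keep 1 group
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(1, 2, 2).reshape(1, -1)
    masked = scores.masked_fill(~score_mask.bool(), 0.0)
    return torch.topk(masked, k=2, dim=-1)[1]  # tensor([[1, 0]]): both from group 0
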
class AddAuxiliaryLoss(torch.autograd.Function):
    """
    The trick function of adding auxiliary (aux) loss,
    which includes the gradient of the aux loss during backpropagation.
    """

    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss

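
# --- Editor's illustrative sketch; `_sketch_aux_loss_grad` is not part of the original file ---
# AddAuxiliaryLoss returns `x` unchanged in the forward pass but feeds a unit
# gradient into `loss` in the backward pass, so the balancing loss trains the
# gate without perturbing the activations:
def _sketch_aux_loss_grad():
    x = torch.ones(3, requires_grad=True)
    aux = (x * 0.5).sum()                  # stand-in auxiliary loss
    y = AddAuxiliaryLoss.apply(x * 2.0, aux)
    y.sum().backward()
    return x.grad                          # tensor([2.5, 2.5, 2.5]) = main 2.0 + aux 0.5
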
class DeepseekV2MoE(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            self.experts = nn.ModuleList(
                [
                    (
                        DeepseekV2MLP(
                            config, intermediate_size=config.moe_intermediate_size
                        )
                        if i >= self.ep_rank * self.experts_per_rank
                        and i < (self.ep_rank + 1) * self.experts_per_rank
                        else None
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        else:
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
                    DeepseekV2MLP(
                        config, intermediate_size=config.moe_intermediate_size
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV2MLP(
                config=config, intermediate_size=intermediate_size
            )

    def forward(self, hidden_states):
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            hidden_states = hidden_states.repeat_interleave(
                self.num_experts_per_tok, dim=0
            )
            y = torch.empty_like(hidden_states)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.to(hidden_states.dtype).view(*orig_shape)
            y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (
                tokens_per_expert_group.view(self.ep_size, -1)
                .sum(1)
                .cpu()
                .numpy()
                .tolist()
            )
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
            )
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank
            ).sum(dim=0)
            gathered_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gathered_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gathered_idxs = gathered_idxs.argsort()
            sorted_tokens = gathered_tokens[gathered_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            new_x = torch.empty_like(outs)
            new_x[gathered_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out

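
# --- Editor's illustrative sketch; `_sketch_expert_dispatch` is not part of the original file ---
# `moe_infer` dispatches by argsorting the flattened (token, expert)
# assignments by expert id, so every expert sees one contiguous token slice;
# `idxs // top_k` recovers which token feeds each slice slot:
def _sketch_expert_dispatch():
    topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])  # 3 tokens -> 2 experts each
    idxs = topk_ids.view(-1).argsort()
    experts_sorted = topk_ids.view(-1)[idxs]           # tensor([0, 0, 1, 1, 2, 2])
    token_of_slot = idxs // topk_ids.shape[1]          # tensor([0, 2, 1, 2, 0, 1])
    return experts_sorted, token_of_slot
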
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2
class DeepseekV2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV2RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)

        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)

        kv_seq_len = k_pe.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            compressed_kv = compressed_kv.unsqueeze(1)
            k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
            compressed_kv = compressed_kv.squeeze(1)

        kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
        q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
        out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]

        q_nope = torch.matmul(q_nope, q_absorb)
        attn_weights = (torch.matmul(q_pe, k_pe.mT) +
                        torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        assert attention_mask is not None
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(q_pe.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)

        attn_output = torch.matmul(attn_output, out_absorb.mT)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

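
# --- Editor's illustrative sketch; `_sketch_mla_absorption` is not part of the original file ---
# The eager MLA path above never materializes per-head keys and values: it
# slices `kv_b_proj` into a query-side factor (q_absorb) and an output-side
# factor (out_absorb), so attention runs directly against the compressed
# latent cache. The shape bookkeeping with made-up sizes:
def _sketch_mla_absorption(heads=2, rank=8, nope=4, v=4, q_len=3):
    kv_b = torch.randn(heads * (nope + v), rank).view(heads, -1, rank)
    q_absorb = kv_b[:, :nope, :]                 # [heads, nope, rank]
    q_nope = torch.randn(1, heads, q_len, nope)
    return torch.matmul(q_nope, q_absorb).shape  # [1, heads, q_len, rank]: latent space
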
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2
class DeepseekV2FlashAttention2(DeepseekV2Attention):
    """
    DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # DeepseekV2FlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

        # TODO: support compressed_kv for kv_cache (instead of key_states, value_states) in flash_attention version
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV2RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = (
                    self.q_proj.weight.dtype
                    if self.q_lora_rank is None
                    else self.q_a_proj.weight.dtype
                )

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

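
# --- Editor's illustrative sketch; `_sketch_cu_seqlens` is not part of the original file ---
# The varlen flash-attention path replaces the padded [batch, seq] layout with
# one packed token stream plus cumulative sequence lengths. How `cu_seqlens`
# falls out of a padding mask (the same arithmetic `_get_unpad_data` uses):
def _sketch_cu_seqlens():
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32)
    seqlens = mask.sum(dim=-1)
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0))
    return cu_seqlens  # tensor([0, 3, 5]): sequence i occupies slots [cu[i], cu[i+1])
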
ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,

    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,

    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
}

class DeepseekV2DecoderLayer(nn.Module):
    def __init__(self, config: DeepseekV2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.use_mla:
            attn_implementation = "mla_" + config._attn_implementation
        else:
            attn_implementation = "mha_" + config._attn_implementation

        self.self_attn = ATTENTION_CLASSES[attn_implementation](
            config=config, layer_idx=layer_idx
        )

        self.mlp = (
            DeepseekV2MoE(config)
            if (
                config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0
            )
            else DeepseekV2MLP(config)
        )
        self.input_layernorm = DeepseekV2RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = DeepseekV2RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

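
# --- Editor's illustrative sketch; `_sketch_prenorm_block` is not part of the original file ---
# The decoder layer is a standard pre-norm residual block: normalize, run the
# sub-layer, then add the untouched residual back. The control flow reduced to
# its skeleton, with a plain callable standing in for attention or the MLP:
def _sketch_prenorm_block(h, norm, sublayer):
    return h + sublayer(norm(h))  # applied once for attention, once for the MLP
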
DeepseekV2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

@add_start_docstrings(
    "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV2_START_DOCSTRING,
)
class DeepseekV2PreTrainedModel(PreTrainedModel):
    config_class = DeepseekV2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

DeepseekV2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV2_START_DOCSTRING,
)
class DeepseekV2Model(DeepseekV2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]

    Args:
        config: DeepseekV2Config
    """

    def __init__(self, config: DeepseekV2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        past_key_values_length = 0
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

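
# --- Editor's illustrative sketch; `_sketch_additive_causal_mask` is not part of the original file ---
# For the eager path the 2-D padding mask is expanded to an additive 4-D
# causal mask; for flash attention the 2-D mask is kept, or dropped entirely
# when the batch has no padding. The additive-causal idea in plain torch:
def _sketch_additive_causal_mask(q_len=4):
    # 0 on and below the diagonal, -inf above: added to the attention logits
    return torch.full((q_len, q_len), float("-inf")).triu(diagonal=1)
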
- class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1645
- _tied_weights_keys = ["lm_head.weight"]
1646
-
1647
- def __init__(self, config):
1648
- super().__init__(config)
1649
- self.model = DeepseekV2Model(config)
1650
- self.vocab_size = config.vocab_size
1651
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1652
-
1653
- # Initialize weights and apply final processing
1654
- self.post_init()
1655
-
1656
- def get_input_embeddings(self):
1657
- return self.model.embed_tokens
1658
-
1659
- def set_input_embeddings(self, value):
1660
- self.model.embed_tokens = value
1661
-
1662
- def get_output_embeddings(self):
1663
- return self.lm_head
1664
-
1665
- def set_output_embeddings(self, new_embeddings):
1666
- self.lm_head = new_embeddings
1667
-
1668
- def set_decoder(self, decoder):
1669
- self.model = decoder
1670
-
1671
    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
                are ignored (masked); the loss is only computed for the tokens with labels in
                `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM

        >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

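The loss block above is the standard causal-LM objective: the logit at position i is scored against the label at position i+1, and `-100` marks positions that `CrossEntropyLoss` ignores. A minimal standalone sketch of the same shift-and-flatten computation, using toy shapes rather than anything from this checkpoint:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 8
logits = torch.randn(2, 5, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))   # (batch, seq_len)
labels[0, 1] = -100                             # masked position, ignored by the loss

# Shift so that tokens < n predict n, exactly as in the forward above.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)  # ignore_index defaults to -100
```
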
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        if self.generation_config.cache_implementation == "static":
            # generation with static cache
            cache_position = kwargs.get("cache_position", None)
            if cache_position is None:
                past_length = 0
            else:
                past_length = cache_position[-1] + 1
            input_ids = input_ids[:, past_length:]
            position_ids = position_ids[:, past_length:]

        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
        # same goes for position ids. Could also help with continued generation.
        cache_position = torch.arange(
            past_length, past_length + position_ids.shape[-1], device=position_ids.device
        )

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids.contiguous(),
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

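The trimming rules in this method exist because, during incremental decoding, tokens already represented in the KV cache must not be re-fed to the model. A minimal sketch of rule 2 with hypothetical lengths (nothing here comes from the class itself):

```python
import torch

input_ids = torch.arange(10).unsqueeze(0)  # (1, 10): the full sequence so far
past_length = 7                            # tokens already covered by the KV cache

# Rule 2: the cache holds the first `past_length` tokens, so only the
# unprocessed suffix is passed to the next forward call.
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]

print(input_ids)  # tensor([[7, 8, 9]])
```
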
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past

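`_reorder_cache` is invoked by beam search whenever hypotheses are permuted: each cached key/value tensor carries the beam dimension first, so an `index_select` along dim 0 keeps every layer's cache aligned with the surviving beams. A toy demonstration with a hypothetical one-layer legacy cache:

```python
import torch

# Hypothetical legacy cache: one layer of (key, value), each (num_beams, heads, seq, head_dim)
key = torch.arange(3.0).view(3, 1, 1, 1).expand(3, 2, 4, 8).clone()
past_key_values = ((key, key.clone()),)

# New beam order: positions 0, 1, 2 are filled from old beams 2, 2, 0.
beam_idx = torch.tensor([2, 2, 0])

reordered = tuple(
    tuple(state.index_select(0, beam_idx) for state in layer)
    for layer in past_key_values
)
assert torch.equal(reordered[0][0][0], past_key_values[0][0][2])
```
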
@add_start_docstrings(
    """
    The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
    each row of the batch).
    """,
    DeepseekV2_START_DOCSTRING,
)
class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss); if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (
                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                ).to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
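
The pooling above leans on an `argmax` trick: `torch.eq(input_ids, pad_token_id).int().argmax(-1)` returns the index of the first pad token in each row, so subtracting 1 yields the last real token under right-padding; for rows with no padding at all, `argmax` over an all-zero mask returns 0 and the resulting -1 indexes the final position. A small standalone illustration, with a hypothetical pad id of 0:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],   # three real tokens, then right-padding
    [8, 9, 3, 4, 2],   # no padding at all
])

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths)  # tensor([2, -1]): last real token, and -1 wraps to the end

logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)
pooled = logits[torch.arange(2), sequence_lengths]  # one row of logits per sequence
```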
__pycache__/conversation.cpython-312.pyc DELETED
Binary file (10.5 kB)
 
__pycache__/deepencoder.cpython-312.pyc DELETED
Binary file (51.5 kB)
 
__pycache__/modeling_deepseekocr.cpython-312.pyc DELETED
Binary file (41.3 kB)