khazic committed on
Commit
f85774e
·
verified ·
1 Parent(s): d603a06

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. Unicorn/.DS_Store +0 -0
  3. Unicorn/bunny/.DS_Store +0 -0
  4. Unicorn/bunny/__init__.py +0 -0
  5. Unicorn/bunny/__pycache__/__init__.cpython-310.pyc +0 -0
  6. Unicorn/bunny/__pycache__/constants.cpython-310.pyc +0 -0
  7. Unicorn/bunny/__pycache__/conversation.cpython-310.pyc +0 -0
  8. Unicorn/bunny/constants.py +7 -0
  9. Unicorn/bunny/conversation.py +239 -0
  10. Unicorn/bunny/model/.DS_Store +0 -0
  11. Unicorn/bunny/model/__init__.py +6 -0
  12. Unicorn/bunny/model/__pycache__/__init__.cpython-310.pyc +0 -0
  13. Unicorn/bunny/model/__pycache__/bunny_arch.cpython-310.pyc +0 -0
  14. Unicorn/bunny/model/builder.py +49 -0
  15. Unicorn/bunny/model/bunny_arch.py +244 -0
  16. Unicorn/bunny/model/language_model/__init__.py +0 -0
  17. Unicorn/bunny/model/language_model/__pycache__/__init__.cpython-310.pyc +0 -0
  18. Unicorn/bunny/model/language_model/__pycache__/bunny_llama.cpython-310.pyc +0 -0
  19. Unicorn/bunny/model/language_model/__pycache__/bunny_minicpm.cpython-310.pyc +0 -0
  20. Unicorn/bunny/model/language_model/__pycache__/bunny_phi.cpython-310.pyc +0 -0
  21. Unicorn/bunny/model/language_model/__pycache__/bunny_phi3.cpython-310.pyc +0 -0
  22. Unicorn/bunny/model/language_model/__pycache__/bunny_qwen.cpython-310.pyc +0 -0
  23. Unicorn/bunny/model/language_model/__pycache__/bunny_stablelm.cpython-310.pyc +0 -0
  24. Unicorn/bunny/model/language_model/bunny_llama.py +103 -0
  25. Unicorn/bunny/model/language_model/bunny_minicpm.py +103 -0
  26. Unicorn/bunny/model/language_model/bunny_phi.py +100 -0
  27. Unicorn/bunny/model/language_model/bunny_phi3.py +100 -0
  28. Unicorn/bunny/model/language_model/bunny_qwen.py +100 -0
  29. Unicorn/bunny/model/language_model/bunny_stablelm.py +100 -0
  30. Unicorn/bunny/model/language_model/llama/__init__.py +114 -0
  31. Unicorn/bunny/model/language_model/llama/__pycache__/__init__.cpython-310.pyc +0 -0
  32. Unicorn/bunny/model/language_model/llama/__pycache__/configuration_llama.cpython-310.pyc +0 -0
  33. Unicorn/bunny/model/language_model/llama/__pycache__/modeling_llama.cpython-310.pyc +0 -0
  34. Unicorn/bunny/model/language_model/llama/configuration_llama.py +191 -0
  35. Unicorn/bunny/model/language_model/llama/modeling_llama.py +1844 -0
  36. Unicorn/bunny/model/language_model/llama/tokenization_llama.py +471 -0
  37. Unicorn/bunny/model/language_model/llama/tokenization_llama_fast.py +281 -0
  38. Unicorn/bunny/model/language_model/minicpm/__pycache__/configuration_minicpm.cpython-310.pyc +0 -0
  39. Unicorn/bunny/model/language_model/minicpm/__pycache__/modeling_minicpm.cpython-310.pyc +0 -0
  40. Unicorn/bunny/model/language_model/minicpm/configuration_minicpm.py +202 -0
  41. Unicorn/bunny/model/language_model/minicpm/modeling_minicpm.py +1456 -0
  42. Unicorn/bunny/model/language_model/phi/__init__.py +69 -0
  43. Unicorn/bunny/model/language_model/phi/__pycache__/__init__.cpython-310.pyc +0 -0
  44. Unicorn/bunny/model/language_model/phi/__pycache__/configuration_phi.cpython-310.pyc +0 -0
  45. Unicorn/bunny/model/language_model/phi/__pycache__/modeling_phi.cpython-310.pyc +0 -0
  46. Unicorn/bunny/model/language_model/phi/configuration_phi.py +195 -0
  47. Unicorn/bunny/model/language_model/phi/modeling_phi.py +1374 -0
  48. Unicorn/bunny/model/language_model/phi3/__init__.py +69 -0
  49. Unicorn/bunny/model/language_model/phi3/__pycache__/__init__.cpython-310.pyc +0 -0
  50. Unicorn/bunny/model/language_model/phi3/__pycache__/configuration_phi3.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Unicorn/wandb/run-20260113_224050-2hice92f/run-2hice92f.wandb filter=lfs diff=lfs merge=lfs -text
37
+ Unicorn/wandb/run-20260114_135552-sjoswxwz/run-sjoswxwz.wandb filter=lfs diff=lfs merge=lfs -text
38
+ Unicorn/wandb/run-20260114_170827-uobkoafb/run-uobkoafb.wandb filter=lfs diff=lfs merge=lfs -text
39
+ Unicorn/wandb/run-20260115_103501-4tsjsu0t/run-4tsjsu0t.wandb filter=lfs diff=lfs merge=lfs -text
40
+ Unicorn/wandb/run-20260115_230712-6c574jt7/run-6c574jt7.wandb filter=lfs diff=lfs merge=lfs -text
Unicorn/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Unicorn/bunny/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Unicorn/bunny/__init__.py ADDED
File without changes
Unicorn/bunny/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
Unicorn/bunny/__pycache__/constants.cpython-310.pyc ADDED
Binary file (347 Bytes). View file
 
Unicorn/bunny/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (5.73 kB). View file
 
Unicorn/bunny/constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Model Constants
# Label value skipped by the loss (matches torch CrossEntropyLoss's default ignore_index).
IGNORE_INDEX = -100
# Sentinel token id marking where image features are spliced into input_ids
# (see prepare_inputs_labels_for_multimodal in bunny_arch.py).
IMAGE_TOKEN_INDEX = -200
# Literal tag used inside prompt text to mark an image position.
DEFAULT_IMAGE_TOKEN = "<image>"
# Heartbeat timing constants for the serving controller/worker (seconds).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
# Directory where the Gradio demo writes its logs.
LOGDIR = "gradio-logs"
WORKER_HEART_BEAT_INTERVAL = 15
Unicorn/bunny/conversation.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List
4
+
5
+
6
class SeparatorStyle(Enum):
    """Separator conventions used when rendering a Conversation prompt.

    TWO   -- role-labelled turns, alternating ``sep`` / ``sep2`` after each turn.
    PLAIN -- raw message text joined with the separators, no role labels.
    """

    # Explicit values; identical to what auto() would assign (1, 2).
    TWO = 1
    PLAIN = 2
10
+
11
+
12
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` holds ``[role, message]`` pairs; for multimodal turns the
    message may instead be a tuple ``(text, image, image_process_mode)``.
    """
    system: str                 # system prompt prepended when rendering
    roles: List[str]            # (user_role, assistant_role) labels
    messages: List[List[str]]   # [role, message] turns (message may be a tuple)
    offset: int                 # index of first message to expose in UIs
    sep_style: SeparatorStyle
    sep: str = "###"
    sep2: str = None            # second separator, used by SeparatorStyle.TWO
    version: str = "Unknown"    # template version tag ('mmtag' alters get_prompt)

    skip_next: bool = False     # UI flag: skip generation for the next turn

    def get_prompt(self):
        """Render the whole history into a single prompt string per sep_style."""
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # First turn carries an image tuple: strip any literal <image> tag
            # from the text and re-insert it per the template version.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.TWO:
            # Role-labelled turns; separators alternate sep / sep2 per turn.
            # An empty message opens a turn for generation ("ROLE:").
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"

        elif self.sep_style == SeparatorStyle.PLAIN:
            # Raw concatenation of message text with alternating separators.
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a [role, message] turn to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect images from user turns (even indices past ``offset``).

        Returns PIL images when ``return_pil`` is True, otherwise
        base64-encoded PNG strings.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:  # user turns only
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the shorter side so the image becomes square.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Convert history to Gradio chatbot pairs, inlining images as <img> tags."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    # Downscale, preserving aspect ratio, so the shorter edge is
                    # at most 400px and the longer edge at most 800px.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # NOTE(review): the bytes are JPEG but the data URI says
                    # image/png — browsers sniff the real type, but this
                    # mismatch looks unintentional and is worth confirming.
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with freshly created [role, message] lists."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Serialize to a plain dict; image tuples are reduced to their text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
168
+
169
+
170
# Default chat template used by Bunny checkpoints: role-labelled turns
# (SeparatorStyle.TWO), assistant turns terminated by sep2.
conv_bunny = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="bunny",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

# Phi-3 variant — same structure as conv_bunny, different version tag.
conv_phi3 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

# MiniCPM variant — ends assistant turns with </s>.
conv_minicpm = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="minicpm",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Llama-3 variant — ends assistant turns with <|end_of_text|>.
conv_llama = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="llama",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|end_of_text|>",
)

# Bare template for pretraining-style data: no system prompt, no role labels.
conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

default_conversation = conv_bunny
# Registry mapping template names (as used by training/serving code) to templates.
conv_templates = {
    "default": conv_bunny,
    "bunny": conv_bunny,
    "phi3": conv_phi3,
    "plain": conv_plain,
    'minicpm': conv_minicpm,
    'llama': conv_llama
}

if __name__ == "__main__":
    # Smoke test: print the rendered default prompt.
    print(default_conversation.get_prompt())
Unicorn/bunny/model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Unicorn/bunny/model/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .language_model.bunny_phi import BunnyPhiForCausalLM, BunnyPhiConfig
2
+ from .language_model.bunny_stablelm import BunnyStableLMForCausalLM, BunnyStableLMConfig
3
+ from .language_model.bunny_qwen import BunnyQwen2ForCausalLM, BunnyQwen2Config
4
+ from .language_model.bunny_minicpm import BunnyMiniCPMForCausalLM, BunnyMiniCPMConfig
5
+ from .language_model.bunny_llama import BunnyLlamaForCausalLM, BunnyLlamaConfig
6
+ from .language_model.bunny_phi3 import BunnyPhi3ForCausalLM, BunnyPhi3Config
Unicorn/bunny/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (754 Bytes). View file
 
Unicorn/bunny/model/__pycache__/bunny_arch.cpython-310.pyc ADDED
Binary file (5.84 kB). View file
 
Unicorn/bunny/model/builder.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import warnings

import torch
import transformers  # fixed: was imported twice in the original prelude
from transformers import AutoModelForCausalLM, AutoTokenizer

# Disable noisy transformers warnings / progress bars for CLI usage.
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

# Add /data/xmyu/Bunny_text to sys.path so the `bunny` package imports below
# resolve regardless of the working directory.
# NOTE(review): machine-specific absolute path; consider deriving it from
# __file__ or an environment variable.
sys.path.insert(0, "/data/xmyu/Bunny_text")
from bunny.model.language_model.bunny_llama import BunnyLlamaConfig, BunnyLlamaForCausalLM
16
+
17
+
18
def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False,
                          device_map="auto", device="cuda", **kwargs):
    """Load the Bunny checkpoint and its tokenizer.

    Returns:
        (tokenizer, model, context_len) — context_len is fixed at 512.

    NOTE(review): every parameter (model_path, model_base, load_8bit, ...)
    is currently ignored; the checkpoint directory is hard-coded below and
    the model is always loaded in float16 on CUDA. Confirm whether the
    parameters should be honored instead.
    """
    # Hard-coded checkpoint directory (mean-shift Llama3-8B run); deduplicated
    # into one local so the model and tokenizer cannot drift apart.
    checkpoint = '/data/xmyu/finished-checkpoints/mean_shift/checkpoints-llama3-8b/bunny-llama3-8b'

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,  # use float32 for CPU inference
        trust_remote_code=True
        # device_map='auto'
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        trust_remote_code=True
    )

    return tokenizer, model, 512
Unicorn/bunny/model/bunny_arch.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import os
3
+ import torch
4
+ from .multimodal_projector.builder import build_vision_projector
5
+
6
+ from bunny.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
7
+
8
+
9
class BunnyMetaModel:
    """Mixin that attaches the multimodal (mm) projector to a language backbone."""

    def __init__(self, config):
        super(BunnyMetaModel, self).__init__(config)

        # Build the projector only when the config already carries
        # mm_hidden_size (instead of an unconditional `if True`):
        # 1. When loading a plain base model at the start of training the
        #    attribute is absent, so construction is skipped here (avoids an
        #    error); train.py later calls initialize_vision_modules() to
        #    initialize it manually.
        # 2. When loading a trained Bunny checkpoint for inference the config
        #    has the attribute and the projector is built directly.
        if hasattr(config, "mm_hidden_size"):
            if getattr(config, 'continuous_training', False):
                config.continuous_training = False
            self.mm_projector = build_vision_projector(config)

    def initialize_vision_modules(self, model_args):
        """Create the mm projector and optionally load pretrained adapter weights.

        ``model_args`` must provide ``pretrain_mm_mlp_adapter`` (path or None)
        and ``mm_projector_type``.
        """
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type')
        # Input feature width the projector expects; fixed to 1280 here.
        self.config.mm_hidden_size = 1280

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
        # else:  (kept for reference) re-enable grads when frozen by LoRA:
        #     for p in self.mm_projector.parameters():
        #         p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                # Keep only keys containing `keyword` and strip the "<keyword>." prefix.
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
+
48
+
49
class BunnyMetaForCausalLM(ABC):
    """Mixin implementing Bunny's multimodal input splicing for causal LMs."""

    @abstractmethod
    def get_model(self):
        """Return the backbone model (must expose ``mm_projector`` and ``embed_tokens``)."""
        pass

    def get_image_feature(self, embeds):
        """Project per-image embeddings into the LM embedding space.

        ``embeds`` is assumed to be [batch, mm_hidden_size] — TODO confirm
        against the caller; it is broadcast across a fixed pseudo-sequence
        before projection.
        """
        # Expand [batch, mm_hidden_size] -> [batch, seq, mm_hidden_size].
        seq = 576
        embeds = embeds.unsqueeze(1).expand(-1, seq, -1)

        embeds = self.get_model().mm_projector(embeds)

        return embeds  # [batch, seq, projector-output width]

    def prepare_inputs_labels_for_multimodal(
            self, input_ids, position_ids, attention_mask, past_key_values, labels, embeds
    ):
        """Replace IMAGE_TOKEN_INDEX placeholders with projected image features.

        Returns (input_ids, position_ids, attention_mask, past_key_values,
        inputs_embeds, labels) — input_ids becomes None when embeddings are
        produced, so callers pass inputs_embeds to the backbone instead.
        """
        # Fast path: no image, or single-token decoding step during generation.
        if embeds is None or input_ids.shape[1] == 1:
            if past_key_values is not None and embeds is not None and input_ids.shape[
                1] == 1:
                # Extend the attention mask to cover the cached prefix plus
                # the one new token, and recompute the next position id.
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat((attention_mask, torch.ones(
                    (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if embeds is not None:
            image_features = self.get_image_feature(embeds)

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        input_ids_temp = input_ids  # points to the actual input_ids tensor

        # remove the padding using attention_mask -- TODO: double check
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
                     zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        # -- TODO: better implementation?
        # replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
        # (mutates the caller's input_ids tensor in place).
        input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # No image token in this row: embed text only, but append a
                # zero-length slice of the image features so the projector
                # still participates in the autograd graph.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Positions of image tokens, bracketed by -1 and len so slicing
            # below yields the text segments between/around them.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
                cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text segments in one call, then split them back apart.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # Interleave: text segment, image features, text segment, ...
            # Image positions get IGNORE_INDEX labels (no loss on them).
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device,
                                   dtype=cur_labels.dtype))

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them: pad every row to the batch's max length.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype,
                                       device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                # Left padding: zeros first, content right-aligned.
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype,
                                                              device=position_ids.device)
            else:
                # Right padding (default): content first, zeros appended.
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype,
                                                             device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Restore None for the inputs the caller originally passed as None.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
Unicorn/bunny/model/language_model/__init__.py ADDED
File without changes
Unicorn/bunny/model/language_model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes). View file
 
Unicorn/bunny/model/language_model/__pycache__/bunny_llama.cpython-310.pyc ADDED
Binary file (3.15 kB). View file
 
Unicorn/bunny/model/language_model/__pycache__/bunny_minicpm.cpython-310.pyc ADDED
Binary file (3.25 kB). View file
 
Unicorn/bunny/model/language_model/__pycache__/bunny_phi.cpython-310.pyc ADDED
Binary file (3.04 kB). View file
 
Unicorn/bunny/model/language_model/__pycache__/bunny_phi3.cpython-310.pyc ADDED
Binary file (3.05 kB). View file
 
Unicorn/bunny/model/language_model/__pycache__/bunny_qwen.cpython-310.pyc ADDED
Binary file (3.06 kB). View file
 
Unicorn/bunny/model/language_model/__pycache__/bunny_stablelm.cpython-310.pyc ADDED
Binary file (3.18 kB). View file
 
Unicorn/bunny/model/language_model/bunny_llama.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import os
3
+ import pickle
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import AutoConfig, AutoModelForCausalLM
7
+
8
+ from .llama import LlamaModel, LlamaConfig, LlamaForCausalLM
9
+
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
13
+
14
+
15
class BunnyLlamaConfig(LlamaConfig):
    """LlamaConfig with a model_type that routes Auto* loading to Bunny-Llama."""
    model_type = "bunny-llama"
17
+
18
+
19
class BunnyLlamaModel(BunnyMetaModel, LlamaModel):
    """Llama backbone extended with Bunny's multimodal projector (BunnyMetaModel)."""
    config_class = BunnyLlamaConfig

    def __init__(self, config: LlamaConfig):
        super(BunnyLlamaModel, self).__init__(config)
24
+
25
+
26
class BunnyLlamaForCausalLM(LlamaForCausalLM, BunnyMetaForCausalLM):
    """Llama causal LM that accepts precomputed image embeddings via ``embeds``."""
    config_class = BunnyLlamaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = BunnyLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the backbone required by BunnyMetaForCausalLM."""
        return self.model

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            embeds: Optional[list] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Splice image embeddings into the token stream, then run Llama forward."""
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                embeds
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            # NOTE(review): the incoming cache_position argument is discarded
            # and None is forwarded — confirm this is intentional for this
            # vendored Llama implementation.
            cache_position=None
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
                                      **kwargs):
        """Forward ``embeds`` through HF generate()'s per-step input preparation."""
        embeds = kwargs.pop("embeds", None)

        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
            **kwargs
        )

        if embeds is not None:
            _inputs['embeds'] = embeds

        return _inputs
100
+
101
+
102
+ AutoConfig.register("bunny-llama", BunnyLlamaConfig)
103
+ AutoModelForCausalLM.register(BunnyLlamaConfig, BunnyLlamaForCausalLM)
Unicorn/bunny/model/language_model/bunny_minicpm.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from bunny.model.language_model.minicpm.modeling_minicpm import MiniCPMModel, MiniCPMForCausalLM
8
+ from bunny.model.language_model.minicpm.configuration_minicpm import MiniCPMConfig
9
+
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
13
+
14
+
15
class BunnyMiniCPMConfig(MiniCPMConfig):
    """MiniCPMConfig with a model_type that routes Auto* loading to Bunny-MiniCPM."""
    model_type = "bunny-minicpm"
17
+
18
+
19
class BunnyMiniCPMModel(BunnyMetaModel, MiniCPMModel):
    """MiniCPM backbone extended with Bunny's multimodal projector (BunnyMetaModel)."""
    config_class = BunnyMiniCPMConfig

    def __init__(self, config: MiniCPMConfig):
        super(BunnyMiniCPMModel, self).__init__(config)
24
+
25
+
26
class BunnyMiniCPMForCausalLM(MiniCPMForCausalLM, BunnyMetaForCausalLM):
    """MiniCPM causal LM that accepts image features via ``images``.

    NOTE(review): unlike BunnyLlamaForCausalLM (which names this argument
    ``embeds``), this class uses ``images`` — confirm callers use the right
    keyword for each model type.
    """
    config_class = BunnyMiniCPMConfig

    def __init__(self, config):
        super(MiniCPMForCausalLM, self).__init__(config)
        self.model = BunnyMiniCPMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the backbone required by BunnyMetaForCausalLM."""
        return self.model

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            images: Optional[torch.FloatTensor] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Splice image features into the token stream, then run MiniCPM forward."""
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images
            )
        if inputs_embeds is not None:
            # MiniCPM scales token embeddings by config.scale_emb; applied
            # in place to the spliced embedding tensor.
            inputs_embeds *= self.get_model().config.scale_emb

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
                                      **kwargs):
        """Forward ``images`` through HF generate()'s per-step input preparation."""
        images = kwargs.pop("images", None)

        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
            **kwargs
        )

        if images is not None:
            _inputs['images'] = images
        return _inputs
100
+
101
+
102
+ AutoConfig.register("bunny-minicpm", BunnyMiniCPMConfig)
103
+ AutoModelForCausalLM.register(BunnyMiniCPMConfig, BunnyMiniCPMForCausalLM)
Unicorn/bunny/model/language_model/bunny_phi.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from .phi import PhiModel, PhiConfig, PhiForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
+
13
+
14
class BunnyPhiConfig(PhiConfig):
    """Phi configuration tagged with the multimodal `bunny-phi` model type."""

    model_type = "bunny-phi"
16
+
17
+
18
class BunnyPhiModel(BunnyMetaModel, PhiModel):
    """Phi backbone combined with the Bunny multimodal mixin (`BunnyMetaModel`)."""

    config_class = BunnyPhiConfig

    def __init__(self, config: PhiConfig):
        super().__init__(config)
23
+
24
+
25
class BunnyPhiForCausalLM(PhiForCausalLM, BunnyMetaForCausalLM):
    """Causal-LM head over :class:`BunnyPhiModel` that also accepts image inputs."""

    config_class = BunnyPhiConfig

    def __init__(self, config):
        # Two-argument super() deliberately skips PhiForCausalLM.__init__ so the
        # plain PhiModel backbone is never built; the Bunny variant goes in instead.
        super(PhiForCausalLM, self).__init__(config)
        self.model = BunnyPhiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the underlying backbone model."""
        return self.model

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            images: Optional[torch.FloatTensor] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal preprocessing step, then the standard LM forward."""
        if inputs_embeds is None:
            # presumably merges image features into the embedding stream and
            # realigns the text-side tensors — see BunnyMetaForCausalLM.
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images)

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
                                      **kwargs):
        """Delegate to the parent implementation, re-attaching `images` if given."""
        images = kwargs.pop("images", None)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

        if images is not None:
            model_inputs['images'] = images
        return model_inputs
97
+
98
+
99
# Make the "bunny-phi" model type resolvable through the HF Auto* factories.
AutoConfig.register("bunny-phi", BunnyPhiConfig)
AutoModelForCausalLM.register(BunnyPhiConfig, BunnyPhiForCausalLM)
Unicorn/bunny/model/language_model/bunny_phi3.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from .phi3 import Phi3Model, Phi3Config, Phi3ForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
+
13
+
14
class BunnyPhi3Config(Phi3Config):
    """Phi-3 configuration tagged with the multimodal `bunny-phi3` model type."""

    model_type = "bunny-phi3"
16
+
17
+
18
class BunnyPhi3Model(BunnyMetaModel, Phi3Model):
    """Phi-3 backbone combined with the Bunny multimodal mixin (`BunnyMetaModel`)."""

    config_class = BunnyPhi3Config

    def __init__(self, config: Phi3Config):
        super().__init__(config)
23
+
24
+
25
class BunnyPhi3ForCausalLM(Phi3ForCausalLM, BunnyMetaForCausalLM):
    """Causal-LM head over :class:`BunnyPhi3Model` that also accepts image inputs."""

    config_class = BunnyPhi3Config

    def __init__(self, config):
        # Two-argument super() deliberately skips Phi3ForCausalLM.__init__ so the
        # plain Phi3Model backbone is never built; the Bunny variant goes in instead.
        super(Phi3ForCausalLM, self).__init__(config)
        self.model = BunnyPhi3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the underlying backbone model."""
        return self.model

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            images: Optional[torch.FloatTensor] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal preprocessing step, then the standard LM forward."""
        if inputs_embeds is None:
            # presumably merges image features into the embedding stream and
            # realigns the text-side tensors — see BunnyMetaForCausalLM.
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images)

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
                                      **kwargs):
        """Delegate to the parent implementation, re-attaching `images` if given."""
        images = kwargs.pop("images", None)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

        if images is not None:
            model_inputs['images'] = images
        return model_inputs
97
+
98
+
99
# Make the "bunny-phi3" model type resolvable through the HF Auto* factories.
AutoConfig.register("bunny-phi3", BunnyPhi3Config)
AutoModelForCausalLM.register(BunnyPhi3Config, BunnyPhi3ForCausalLM)
Unicorn/bunny/model/language_model/bunny_qwen.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from .qwen2 import Qwen2Model, Qwen2Config, Qwen2ForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
+
13
+
14
class BunnyQwen2Config(Qwen2Config):
    """Qwen2 configuration tagged with the multimodal `bunny-qwen2` model type."""

    model_type = "bunny-qwen2"
16
+
17
+
18
class BunnyQwen2Model(BunnyMetaModel, Qwen2Model):
    """Qwen2 backbone combined with the Bunny multimodal mixin (`BunnyMetaModel`)."""

    config_class = BunnyQwen2Config

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
23
+
24
+
25
class BunnyQwen2ForCausalLM(Qwen2ForCausalLM, BunnyMetaForCausalLM):
    """Causal-LM head over :class:`BunnyQwen2Model` that also accepts image inputs."""

    config_class = BunnyQwen2Config

    def __init__(self, config):
        # Two-argument super() deliberately skips Qwen2ForCausalLM.__init__ so the
        # plain Qwen2Model backbone is never built; the Bunny variant goes in instead.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = BunnyQwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the underlying backbone model."""
        return self.model

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            images: Optional[torch.FloatTensor] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal preprocessing step, then the standard LM forward."""
        if inputs_embeds is None:
            # presumably merges image features into the embedding stream and
            # realigns the text-side tensors — see BunnyMetaForCausalLM.
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images)

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
                                      **kwargs):
        """Delegate to the parent implementation, re-attaching `images` if given."""
        images = kwargs.pop("images", None)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

        if images is not None:
            model_inputs['images'] = images
        return model_inputs
97
+
98
+
99
# Make the "bunny-qwen2" model type resolvable through the HF Auto* factories.
AutoConfig.register("bunny-qwen2", BunnyQwen2Config)
AutoModelForCausalLM.register(BunnyQwen2Config, BunnyQwen2ForCausalLM)
Unicorn/bunny/model/language_model/bunny_stablelm.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from bunny.model.language_model.stable_lm.modeling_stablelm_epoch import StableLMEpochModel, StableLMEpochConfig, \
8
+ StableLMEpochForCausalLM
9
+
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+ from bunny.model.bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
13
+
14
+
15
class BunnyStableLMConfig(StableLMEpochConfig):
    """StableLM-Epoch configuration tagged with the `bunny-stablelm` model type."""

    model_type = "bunny-stablelm"
17
+
18
+
19
class BunnyStableLMModel(BunnyMetaModel, StableLMEpochModel):
    """StableLM-Epoch backbone combined with the Bunny multimodal mixin."""

    config_class = BunnyStableLMConfig

    def __init__(self, config: StableLMEpochConfig):
        super().__init__(config)
24
+
25
+
26
class BunnyStableLMForCausalLM(StableLMEpochForCausalLM, BunnyMetaForCausalLM):
    """Causal-LM head over :class:`BunnyStableLMModel` that also accepts image inputs."""

    config_class = BunnyStableLMConfig

    def __init__(self, config):
        # Two-argument super() deliberately skips StableLMEpochForCausalLM.__init__ so
        # the plain backbone is never built; the Bunny variant goes in instead.
        super(StableLMEpochForCausalLM, self).__init__(config)
        self.model = BunnyStableLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the underlying backbone model."""
        return self.model

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            images: Optional[torch.FloatTensor] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal preprocessing step, then the standard LM forward."""
        if inputs_embeds is None:
            # presumably merges image features into the embedding stream and
            # realigns the text-side tensors — see BunnyMetaForCausalLM.
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images)

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
                                      **kwargs):
        """Delegate to the parent implementation, re-attaching `images` if given."""
        images = kwargs.pop("images", None)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

        if images is not None:
            model_inputs['images'] = images
        return model_inputs
97
+
98
+
99
# Make the "bunny-stablelm" model type resolvable through the HF Auto* factories.
AutoConfig.register("bunny-stablelm", BunnyStableLMConfig)
AutoModelForCausalLM.register(BunnyStableLMConfig, BunnyStableLMForCausalLM)
Unicorn/bunny/model/language_model/llama/__init__.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import (
17
+ OptionalDependencyNotAvailable,
18
+ _LazyModule,
19
+ is_flax_available,
20
+ is_sentencepiece_available,
21
+ is_tokenizers_available,
22
+ is_torch_available,
23
+ )
24
+
25
+
26
_import_structure = {
    "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"],
}

# Populate the lazy-import table: each optional submodule is registered only
# when its backend dependency is actually installed.
for _checker, _submodule, _exports in (
    (is_sentencepiece_available, "tokenization_llama", ["LlamaTokenizer"]),
    (is_tokenizers_available, "tokenization_llama_fast", ["LlamaTokenizerFast"]),
    (
        is_torch_available,
        "modeling_llama",
        [
            "LlamaForCausalLM",
            "LlamaModel",
            "LlamaPreTrainedModel",
            "LlamaForSequenceClassification",
            "LlamaForQuestionAnswering",
        ],
    ),
    (
        is_flax_available,
        "modeling_flax_llama",
        ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"],
    ),
):
    try:
        if not _checker():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        _import_structure[_submodule] = _exports


if TYPE_CHECKING:
    # Static-analysis-only imports, mirroring the runtime table above.
    from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig

    if is_sentencepiece_available():
        from .tokenization_llama import LlamaTokenizer

    if is_tokenizers_available():
        from .tokenization_llama_fast import LlamaTokenizerFast

    if is_torch_available():
        from .modeling_llama import (
            LlamaForCausalLM,
            LlamaForQuestionAnswering,
            LlamaForSequenceClassification,
            LlamaModel,
            LlamaPreTrainedModel,
        )

    if is_flax_available():
        from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel

else:
    import sys

    # Replace this module with a lazy proxy: submodules are imported only on
    # first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Unicorn/bunny/model/language_model/llama/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.66 kB). View file
 
Unicorn/bunny/model/language_model/llama/__pycache__/configuration_llama.cpython-310.pyc ADDED
Binary file (7.78 kB). View file
 
Unicorn/bunny/model/language_model/llama/__pycache__/modeling_llama.cpython-310.pyc ADDED
Binary file (55.3 kB). View file
 
Unicorn/bunny/model/language_model/llama/configuration_llama.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ # from ..deprecated._archive_maps import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
30
+
31
+
32
class LlamaConfig(PretrainedConfig):
    r"""
    Configuration class for [`LlamaModel`]. Instantiating it with the defaults
    yields a configuration similar to LLaMA-7B.

    Inherits from [`PretrainedConfig`]; read that class's documentation for the
    generic options it controls.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size — the number of distinct token ids accepted by [`LlamaModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the decoder.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads implementing Grouped Query Attention. Equal to
            `num_attention_heads` gives Multi Head Attention (MHA); 1 gives Multi Query
            Attention (MQA); anything in between gives GQA. When converting a multi-head
            checkpoint to GQA, each group's key/value head should be the mean-pool of the
            original heads in that group (https://arxiv.org/pdf/2305.13245.pdf).
            Defaults to `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            Maximum sequence length the model may be used with. Llama 1 supports up to
            2048 tokens, Llama 2 up to 4096, CodeLlama up to 16384.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal initializer for weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/value attentions. Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning-of-stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental. Tensor-parallelism rank used during pretraining; needed to
            reproduce pretraining results exactly (see
            https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie input and output word embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            RoPE scaling configuration. Supports the `linear` and `dynamic` strategies
            with a float scaling factor > 1; expected format is
            `{"type": strategy name, "factor": scaling factor}`. Do not also update
            `max_position_embeddings` when using this. Experimental feature.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projections.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import LlamaModel, LlamaConfig

    >>> # Initializing a LLaMA llama-7b style configuration
    >>> configuration = LlamaConfig()

    >>> # Initializing a model from the llama-7b style configuration
    >>> model = LlamaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Backward compatibility: older checkpoints predate GQA and omit this field.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache

        # RoPE settings are validated eagerly so a malformed dict fails at
        # construction time rather than deep inside the forward pass.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """Validate the `rope_scaling` configuration, raising ValueError if malformed."""
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
            )

        scaling_type = self.rope_scaling.get("type", None)
        scaling_factor = self.rope_scaling.get("factor", None)

        if scaling_type is None or scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {scaling_type}"
            )
        if scaling_factor is None or not isinstance(scaling_factor, float) or scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {scaling_factor}")
Unicorn/bunny/model/language_model/llama/modeling_llama.py ADDED
@@ -0,0 +1,1844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch LLaMA model."""
21
+
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
+ # from transformers.modeling_attn_mask_utils import AttentionMaskConverter
35
+ from dataclasses import dataclass
36
+ @dataclass
37
+ class AttentionMaskConverter:
38
+ """
39
+ A utility attention mask class that allows one to:
40
+ - Create a causal 4d mask
41
+ - Create a causal 4d mask with slided window
42
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
43
+ key_value_length) that can be multiplied with attention scores
44
+
45
+ Examples:
46
+
47
+ ```python
48
+ >>> import torch
49
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
50
+
51
+ >>> converter = AttentionMaskConverter(True)
52
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
53
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
54
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
55
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
56
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
57
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
58
+ ```
59
+
60
+ Parameters:
61
+ is_causal (`bool`):
62
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
63
+
64
+ sliding_window (`int`, *optional*):
65
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
66
+ """
67
+
68
+ is_causal: bool
69
+ sliding_window: int
70
+
71
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
72
+ self.is_causal = is_causal
73
+ self.sliding_window = sliding_window
74
+
75
+ if self.sliding_window is not None and self.sliding_window <= 0:
76
+ raise ValueError(
77
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
78
+ )
79
+
80
+ def to_causal_4d(
81
+ self,
82
+ batch_size: int,
83
+ query_length: int,
84
+ key_value_length: int,
85
+ dtype: torch.dtype,
86
+ device: Union[torch.device, "str"] = "cpu",
87
+ ) -> Optional[torch.Tensor]:
88
+ """
89
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
90
+ bias to upper right hand triangular matrix (causal mask).
91
+ """
92
+ if not self.is_causal:
93
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
94
+
95
+ # If shape is not cached, create a new causal mask and cache it
96
+ input_shape = (batch_size, query_length)
97
+ past_key_values_length = key_value_length - query_length
98
+
99
+ # create causal mask
100
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
101
+ causal_4d_mask = None
102
+ if input_shape[-1] > 1 or self.sliding_window is not None:
103
+ causal_4d_mask = self._make_causal_mask(
104
+ input_shape,
105
+ dtype,
106
+ device=device,
107
+ past_key_values_length=past_key_values_length,
108
+ sliding_window=self.sliding_window,
109
+ )
110
+
111
+ return causal_4d_mask
112
+
113
+ def to_4d(
114
+ self,
115
+ attention_mask_2d: torch.Tensor,
116
+ query_length: int,
117
+ dtype: torch.dtype,
118
+ key_value_length: Optional[int] = None,
119
+ ) -> torch.Tensor:
120
+ """
121
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
122
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
123
+ causal, a causal mask will be added.
124
+ """
125
+ input_shape = (attention_mask_2d.shape[0], query_length)
126
+
127
+ # create causal mask
128
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
129
+ causal_4d_mask = None
130
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
131
+ if key_value_length is None:
132
+ raise ValueError(
133
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
134
+ )
135
+
136
+ past_key_values_length = key_value_length - query_length
137
+ causal_4d_mask = self._make_causal_mask(
138
+ input_shape,
139
+ dtype,
140
+ device=attention_mask_2d.device,
141
+ past_key_values_length=past_key_values_length,
142
+ sliding_window=self.sliding_window,
143
+ )
144
+ elif self.sliding_window is not None:
145
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
146
+
147
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
148
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
149
+ attention_mask_2d.device
150
+ )
151
+
152
+ if causal_4d_mask is not None:
153
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
154
+
155
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
156
+ expanded_4d_mask = expanded_attn_mask
157
+
158
+ return expanded_4d_mask
159
+
160
+ @staticmethod
161
+ def _make_causal_mask(
162
+ input_ids_shape: torch.Size,
163
+ dtype: torch.dtype,
164
+ device: torch.device,
165
+ past_key_values_length: int = 0,
166
+ sliding_window: Optional[int] = None,
167
+ ):
168
+ """
169
+ Make causal mask used for bi-directional self-attention.
170
+ """
171
+ bsz, tgt_len = input_ids_shape
172
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
173
+ mask_cond = torch.arange(mask.size(-1), device=device)
174
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
175
+
176
+ mask = mask.to(dtype)
177
+
178
+ if past_key_values_length > 0:
179
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
180
+
181
+ # add lower triangular sliding window mask if necessary
182
+ if sliding_window is not None:
183
+ diagonal = past_key_values_length - sliding_window - 1
184
+
185
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
186
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
187
+
188
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
189
+
190
+ @staticmethod
191
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
192
+ """
193
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
194
+ """
195
+ bsz, src_len = mask.size()
196
+ tgt_len = tgt_len if tgt_len is not None else src_len
197
+
198
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
199
+
200
+ inverted_mask = 1.0 - expanded_mask
201
+
202
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
203
+
204
+ @staticmethod
205
+ def _unmask_unattended(
206
+ expanded_mask: torch.FloatTensor,
207
+ min_dtype: float,
208
+ ):
209
+ # fmt: off
210
+ """
211
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
212
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
213
+ Details: https://github.com/pytorch/pytorch/issues/110213
214
+
215
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
216
+ `attention_mask` is [bsz, src_seq_len].
217
+
218
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
219
+
220
+ For example, if `expanded_mask` is (e.g. here left-padding case)
221
+ ```
222
+ [[[[0, 0, 0],
223
+ [0, 0, 0],
224
+ [0, 0, 1]]],
225
+ [[[1, 0, 0],
226
+ [1, 1, 0],
227
+ [1, 1, 1]]],
228
+ [[[0, 0, 0],
229
+ [0, 1, 0],
230
+ [0, 1, 1]]]]
231
+ ```
232
+ then the modified `expanded_mask` will be
233
+ ```
234
+ [[[[1, 1, 1], <-- modified
235
+ [1, 1, 1], <-- modified
236
+ [0, 0, 1]]],
237
+ [[[1, 0, 0],
238
+ [1, 1, 0],
239
+ [1, 1, 1]]],
240
+ [[[1, 1, 1], <-- modified
241
+ [0, 1, 0],
242
+ [0, 1, 1]]]]
243
+ ```
244
+ """
245
+ # fmt: on
246
+ if expanded_mask.dtype == torch.bool:
247
+ raise ValueError(
248
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
249
+ )
250
+
251
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
252
+
253
+ @staticmethod
254
+ def _ignore_causal_mask_sdpa(
255
+ attention_mask: Optional[torch.Tensor],
256
+ inputs_embeds: torch.Tensor,
257
+ past_key_values_length: int,
258
+ sliding_window: Optional[int] = None,
259
+ ) -> bool:
260
+ """
261
+ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
262
+
263
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
264
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
265
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
266
+ """
267
+
268
+ batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
269
+ key_value_length = query_length + past_key_values_length
270
+
271
+ is_tracing = (
272
+ torch.jit.is_tracing()
273
+ or isinstance(inputs_embeds, torch.fx.Proxy)
274
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
275
+ )
276
+
277
+ ignore_causal_mask = False
278
+
279
+ if attention_mask is None:
280
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
281
+ # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
282
+ # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
283
+ #
284
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
285
+ if (
286
+ not is_tracing
287
+ and (query_length == 1 or key_value_length == query_length)
288
+ and (sliding_window is None or key_value_length < sliding_window)
289
+ ):
290
+ ignore_causal_mask = True
291
+ elif sliding_window is None or key_value_length < sliding_window:
292
+ if len(attention_mask.shape) == 4:
293
+ expected_shape = (batch_size, 1, query_length, key_value_length)
294
+ if tuple(attention_mask.shape) != expected_shape:
295
+ raise ValueError(
296
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
297
+ )
298
+ elif not is_tracing and torch.all(attention_mask == 1):
299
+ if query_length == 1 or key_value_length == query_length:
300
+ # For query_length == 1, causal attention and bi-directional attention are the same.
301
+ ignore_causal_mask = True
302
+
303
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
304
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
305
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
306
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
307
+
308
+ return ignore_causal_mask
309
+
310
+
311
+ from transformers.modeling_outputs import (
312
+ BaseModelOutputWithPast,
313
+ CausalLMOutputWithPast,
314
+ QuestionAnsweringModelOutput,
315
+ SequenceClassifierOutputWithPast,
316
+ )
317
+ from transformers.modeling_utils import PreTrainedModel
318
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
319
+ from transformers.utils import (
320
+ add_start_docstrings,
321
+ add_start_docstrings_to_model_forward,
322
+ is_flash_attn_2_available,
323
+ is_flash_attn_greater_or_equal_2_10,
324
+ logging,
325
+ replace_return_docstrings,
326
+ )
327
+ from .configuration_llama import LlamaConfig
328
+
329
+
330
+ if is_flash_attn_2_available():
331
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
332
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
333
+
334
+
335
+ logger = logging.get_logger(__name__)
336
+
337
+ _CONFIG_FOR_DOC = "LlamaConfig"
338
+
339
+
340
+ def _get_unpad_data(attention_mask):
341
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
342
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
343
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
344
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
345
+ return (
346
+ indices,
347
+ cu_seqlens,
348
+ max_seqlen_in_batch,
349
+ )
350
+
351
+
352
+ class LlamaRMSNorm(nn.Module):
353
+ def __init__(self, hidden_size, eps=1e-6):
354
+ """
355
+ LlamaRMSNorm is equivalent to T5LayerNorm
356
+ """
357
+ super().__init__()
358
+ self.weight = nn.Parameter(torch.ones(hidden_size))
359
+ self.variance_epsilon = eps
360
+
361
+ def forward(self, hidden_states):
362
+ input_dtype = hidden_states.dtype
363
+ hidden_states = hidden_states.to(torch.float32)
364
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
365
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
366
+ return self.weight * hidden_states.to(input_dtype)
367
+
368
+
369
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
370
+
371
+
372
+ class LlamaRotaryEmbedding(nn.Module):
373
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
374
+ super().__init__()
375
+ self.scaling_factor = scaling_factor
376
+ self.dim = dim
377
+ self.max_position_embeddings = max_position_embeddings
378
+ self.base = base
379
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
380
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
381
+ # For BC we register cos and sin cached
382
+ self.max_seq_len_cached = max_position_embeddings
383
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
384
+ t = t / self.scaling_factor
385
+ freqs = torch.outer(t, self.inv_freq)
386
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
387
+ emb = torch.cat((freqs, freqs), dim=-1)
388
+ self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
389
+ self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
390
+
391
+ @property
392
+ def sin_cached(self):
393
+ logger.warning_once(
394
+ "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
395
+ "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
396
+ )
397
+ return self._sin_cached
398
+
399
+ @property
400
+ def cos_cached(self):
401
+ logger.warning_once(
402
+ "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
403
+ "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
404
+ )
405
+ return self._cos_cached
406
+
407
+ @torch.no_grad()
408
+ def forward(self, x, position_ids):
409
+ # x: [bs, num_attention_heads, seq_len, head_size]
410
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
411
+ position_ids_expanded = position_ids[:, None, :].float()
412
+ # Force float32 since bfloat16 loses precision on long contexts
413
+ # See https://github.com/huggingface/transformers/pull/29285
414
+ device_type = x.device.type
415
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
416
+ with torch.autocast(device_type=device_type, enabled=False):
417
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
418
+ emb = torch.cat((freqs, freqs), dim=-1)
419
+ cos = emb.cos()
420
+ sin = emb.sin()
421
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
422
+
423
+
424
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
425
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
426
+
427
+ def forward(self, x, position_ids):
428
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
429
+ position_ids = position_ids.float() / self.scaling_factor
430
+ cos, sin = super().forward(x, position_ids)
431
+ return cos, sin
432
+
433
+
434
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
435
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
436
+
437
+ def forward(self, x, position_ids):
438
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
439
+ seq_len = torch.max(position_ids) + 1
440
+ if seq_len > self.max_position_embeddings:
441
+ base = self.base * (
442
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
443
+ ) ** (self.dim / (self.dim - 2))
444
+ inv_freq = 1.0 / (
445
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
446
+ )
447
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
448
+
449
+ cos, sin = super().forward(x, position_ids)
450
+ return cos, sin
451
+
452
+
453
+ def rotate_half(x):
454
+ """Rotates half the hidden dims of the input."""
455
+ x1 = x[..., : x.shape[-1] // 2]
456
+ x2 = x[..., x.shape[-1] // 2 :]
457
+ return torch.cat((-x2, x1), dim=-1)
458
+
459
+
460
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
461
+ """Applies Rotary Position Embedding to the query and key tensors.
462
+
463
+ Args:
464
+ q (`torch.Tensor`): The query tensor.
465
+ k (`torch.Tensor`): The key tensor.
466
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
467
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
468
+ position_ids (`torch.Tensor`, *optional*):
469
+ Deprecated and unused.
470
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
471
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
472
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
473
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
474
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
475
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
476
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
477
+ Returns:
478
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
479
+ """
480
+ cos = cos.unsqueeze(unsqueeze_dim)
481
+ sin = sin.unsqueeze(unsqueeze_dim)
482
+ q_embed = (q * cos) + (rotate_half(q) * sin)
483
+ k_embed = (k * cos) + (rotate_half(k) * sin)
484
+ return q_embed, k_embed
485
+
486
+
487
+ class LlamaMLP(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.hidden_size = config.hidden_size
492
+ self.intermediate_size = config.intermediate_size
493
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
494
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
495
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
496
+ self.act_fn = ACT2FN[config.hidden_act]
497
+
498
+ def forward(self, x):
499
+ if self.config.pretraining_tp > 1:
500
+ slice = self.intermediate_size // self.config.pretraining_tp
501
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
502
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
503
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
504
+
505
+ gate_proj = torch.cat(
506
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
507
+ )
508
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
509
+
510
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
511
+ down_proj = [
512
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
513
+ ]
514
+ down_proj = sum(down_proj)
515
+ else:
516
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
517
+
518
+ return down_proj
519
+
520
+
521
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
522
+ """
523
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
524
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
525
+ """
526
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
527
+ if n_rep == 1:
528
+ return hidden_states
529
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
530
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
531
+
532
+
533
+ class LlamaAttention(nn.Module):
534
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
535
+
536
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
537
+ super().__init__()
538
+ self.config = config
539
+ self.layer_idx = layer_idx
540
+ if layer_idx is None:
541
+ logger.warning_once(
542
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
543
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
544
+ "when creating this class."
545
+ )
546
+
547
+ self.attention_dropout = config.attention_dropout
548
+ self.hidden_size = config.hidden_size
549
+ self.num_heads = config.num_attention_heads
550
+ self.head_dim = self.hidden_size // self.num_heads
551
+ self.num_key_value_heads = config.num_key_value_heads
552
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
553
+ self.max_position_embeddings = config.max_position_embeddings
554
+ self.rope_theta = config.rope_theta
555
+ self.is_causal = True
556
+
557
+ if (self.head_dim * self.num_heads) != self.hidden_size:
558
+ raise ValueError(
559
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
560
+ f" and `num_heads`: {self.num_heads})."
561
+ )
562
+
563
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
564
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
565
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
566
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
567
+ self._init_rope()
568
+
569
+ def _init_rope(self):
570
+ if self.config.rope_scaling is None:
571
+ self.rotary_emb = LlamaRotaryEmbedding(
572
+ self.head_dim,
573
+ max_position_embeddings=self.max_position_embeddings,
574
+ base=self.rope_theta,
575
+ )
576
+ else:
577
+ scaling_type = self.config.rope_scaling["type"]
578
+ scaling_factor = self.config.rope_scaling["factor"]
579
+ if scaling_type == "linear":
580
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
581
+ self.head_dim,
582
+ max_position_embeddings=self.max_position_embeddings,
583
+ scaling_factor=scaling_factor,
584
+ base=self.rope_theta,
585
+ )
586
+ elif scaling_type == "dynamic":
587
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
588
+ self.head_dim,
589
+ max_position_embeddings=self.max_position_embeddings,
590
+ scaling_factor=scaling_factor,
591
+ base=self.rope_theta,
592
+ )
593
+ else:
594
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
595
+
596
+ def forward(
597
+ self,
598
+ hidden_states: torch.Tensor,
599
+ attention_mask: Optional[torch.Tensor] = None,
600
+ position_ids: Optional[torch.LongTensor] = None,
601
+ past_key_value: Optional[Cache] = None,
602
+ output_attentions: bool = False,
603
+ use_cache: bool = False,
604
+ cache_position: Optional[torch.LongTensor] = None,
605
+ **kwargs,
606
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
607
+ bsz, q_len, _ = hidden_states.size()
608
+
609
+ if self.config.pretraining_tp > 1:
610
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
611
+ query_slices = self.q_proj.weight.split(
612
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
613
+ )
614
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
615
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
616
+
617
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
618
+ query_states = torch.cat(query_states, dim=-1)
619
+
620
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
621
+ key_states = torch.cat(key_states, dim=-1)
622
+
623
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
624
+ value_states = torch.cat(value_states, dim=-1)
625
+
626
+ else:
627
+ query_states = self.q_proj(hidden_states)
628
+ key_states = self.k_proj(hidden_states)
629
+ value_states = self.v_proj(hidden_states)
630
+
631
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
632
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
633
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
634
+
635
+ past_key_value = getattr(self, "past_key_value", past_key_value)
636
+ cos, sin = self.rotary_emb(value_states, position_ids)
637
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
638
+
639
+ if past_key_value is not None:
640
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
641
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
642
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
643
+
644
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
645
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
646
+
647
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
648
+
649
+ if attention_mask is not None: # no matter the length, we just slice it
650
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
651
+ attn_weights = attn_weights + causal_mask
652
+
653
+ # upcast attention to fp32
654
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
655
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
656
+ attn_output = torch.matmul(attn_weights, value_states)
657
+
658
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
659
+ raise ValueError(
660
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
661
+ f" {attn_output.size()}"
662
+ )
663
+
664
+ attn_output = attn_output.transpose(1, 2).contiguous()
665
+
666
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
667
+
668
+ if self.config.pretraining_tp > 1:
669
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
670
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
671
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
672
+ else:
673
+ attn_output = self.o_proj(attn_output)
674
+
675
+ if not output_attentions:
676
+ attn_weights = None
677
+
678
+ return attn_output, attn_weights, past_key_value
679
+
680
+
681
+ class LlamaFlashAttention2(LlamaAttention):
682
+ """
683
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
684
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
685
+ flash attention and deal with padding tokens in case the input contains any of them.
686
+ """
687
+
688
+ def __init__(self, *args, **kwargs):
689
+ super().__init__(*args, **kwargs)
690
+
691
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
692
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
693
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
694
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
695
+
696
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Flash-Attention-2 forward pass.

        Projects `hidden_states` to Q/K/V, applies rotary position embeddings,
        updates the KV cache (when one is present), then computes attention via
        the flash-attn kernels in `_flash_attention_forward`.

        Note: attention weights are never materialized by flash-attn, so
        `output_attentions` is forcibly disabled and `attn_weights` is always
        returned as None.
        """
        # Flash-attn cannot return attention probabilities; override the flag.
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # In case a static cache is used, it lives as an instance attribute and
        # takes precedence over the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Dropout is only active during training.
        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )

        # Collapse heads back to the model's hidden size before the output proj.
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Sequence length of the query states; used to unpad the queries and re-pad the attention output.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # Variable-length kernel: operates on the concatenated, unpadded
            # token stream described by the cumulative sequence lengths.
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Scatter the unpadded outputs back into the padded layout.
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Strip padding tokens from Q/K/V for the flash-attn varlen kernel.

        Returns the unpadded query/key/value layers plus the query token
        indices, the `(cu_seqlens_q, cu_seqlens_k)` cumulative-length tensors,
        and the `(max_seqlen_q, max_seqlen_k)` pair that
        `flash_attn_varlen_func` requires.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            # Prefill: queries share the key layout, reuse the key metadata.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-step decode: one query token per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
class LlamaSdpaAttention(LlamaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # SDPA cannot return attention probabilities; fall back to the
            # eager (parent-class) implementation for that case.
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # In case static cache is used, it is an instance attribute.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped KV heads so SDPA sees one KV head per query head.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Trim the mask to the actual key length (may be shorter than the
            # pre-allocated mask when a cache is in use).
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
        # relying on the `is_causal` argument.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=causal_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # SDPA never materializes attention weights, hence the None.
        return attn_output, None, past_key_value
# Dispatch table: maps `config._attn_implementation` to the attention class
# instantiated by `LlamaDecoderLayer`.
LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer block.

    Structure: pre-RMSNorm -> self-attention -> residual add, then
    pre-RMSNorm -> gated MLP -> residual add. The attention implementation
    (eager / flash-attn-2 / SDPA) is selected from ``config._attn_implementation``.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Pick the attention backend requested by the config.
        attn_cls = LLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # --- self-attention sub-block (pre-norm + residual) ---
        shortcut = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = shortcut + attn_out

        # --- feed-forward sub-block (pre-norm + residual) ---
        shortcut = hidden_states
        hidden_states = shortcut + self.mlp(self.post_attention_layernorm(hidden_states))

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
# Shared class-level docstring, injected into model classes by the
# `add_start_docstrings` decorator below.
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    # Base class shared by all Llama heads: weight initialization, KV-cache
    # setup/teardown, and the capability flags read by transformers' loader.
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range); zero biases and padding rows."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
        """Attach a `cache_cls` KV cache to every attention layer.

        Raises:
            ValueError: if a static cache is requested with flash-attention-2,
                which does not support it.
        """
        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        for layer in self.model.layers:
            # Allocate on the layer's own device; use the pre-quantization
            # dtype when the model weights are quantized.
            device = layer.input_layernorm.weight.device
            if hasattr(self.config, "_pre_quantization_dtype"):
                dtype = self.config._pre_quantization_dtype
            else:
                dtype = layer.self_attn.o_proj.weight.dtype
            layer.self_attn.past_key_value = cache_cls(
                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
            )

    def _reset_cache(self):
        """Drop the per-layer caches installed by `_setup_cache`."""
        for layer in self.model.layers:
            layer.self_attn.past_key_value = None
# Shared `forward` docstring, injected by `add_start_docstrings_to_model_forward`.
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding table."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding table."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Fall back to config defaults for any unset flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR: exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            # Legacy tuple caches are converted to DynamicCache; StaticCache
            # instances are left as-is.
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Recompute activations in backward instead of storing them.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache slot index depends on whether attentions were returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            # Return the cache in the same (legacy vs Cache) format it came in.
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens: int,
    ):
        """Build the 4D additive causal mask for the current forward pass.

        Returns None when the chosen attention backend can handle causality by
        itself (flash-attn with no padding, or SDPA's `is_causal` fast path).
        """
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            # Flash-attn only needs the 2D padding mask (and only if there is
            # actual padding); causality is handled inside the kernel.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        if self.config._attn_implementation == "sdpa":
            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
            # in order to dispatch on Flash Attention 2.
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
            target_length = self.config.max_position_embeddings
        else:  # dynamic cache
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # Start from a fully masked (min_dtype) matrix and carve out the
        # positions each query token may attend to.
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                # Merge the 2D padding mask into the causal mask.
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                # cache. In that case, the 4D attention mask attends to the newest tokens only.
                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                    offset = cache_position[0]
                else:
                    offset = 0
                mask_shape = attention_mask.shape
                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
                causal_mask[
                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
                ] = mask_slice

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
class LlamaForCausalLM(LlamaPreTrainedModel):
    """The Llama decoder (`LlamaModel`) with a causal language-modeling head on top."""

    # The LM head weight may be tied to the input embedding matrix by the base class.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        # Projection from hidden states to vocabulary logits; no bias, matching the released weights.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module of the underlying decoder."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the underlying decoder."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head (output projection)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (output projection)."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap in a different decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Tensor-parallel pretraining layout: split the LM head row-wise, project each
            # slice separately, then concatenate the partial logits along the vocab dim.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        # Cast to float32 so the loss/softmax is computed in full precision.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        """Assemble inputs for one `generate` step: trim tokens already covered by the KV cache,
        derive `position_ids` from the attention mask, and compute `cache_position` if absent."""
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # Legacy tuple cache: length is the key tensor's sequence dimension.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            else:
                remove_prefix_length = input_ids.shape[1] - 1
                input_ids = input_ids[:, remove_prefix_length:]
            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        # Static cache lives inside the attention layers, so it must not be re-passed as an input.
        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder the legacy tuple KV cache along the batch dim to follow beam-search reindexing."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
1623
+
1624
+
1625
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # Classification head applied to the pooled (last-token) hidden state.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module of the underlying decoder."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the underlying decoder."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-token logits; pooled to one position per sequence below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token we cannot locate the last real token per row, so batching is ambiguous.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                # With inputs_embeds we cannot see pad ids; fall back to the last position.
                sequence_lengths = -1

        # Pool: pick each row's logits at its last non-pad position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer problem type once from num_labels and label dtype when not configured.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1746
+
1747
+
1748
@add_start_docstrings(
    """
    The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    base_model_prefix = "transformer"

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
        super().__init__(config)
        self.transformer = LlamaModel(config)
        # Two outputs per token: span-start and span-end logits.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module of the underlying decoder."""
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the underlying decoder."""
        self.transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the 2-channel projection into separate start/end logit tensors.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Clamped out-of-range positions equal ignore_index and contribute no loss.
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Unicorn/bunny/model/language_model/llama/tokenization_llama.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ """Tokenization classes for LLaMA."""
22
+ import os
23
+ from shutil import copyfile
24
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
25
+
26
+ import sentencepiece as spm
27
+
28
+ from transformers.convert_slow_tokenizer import import_protobuf
29
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
30
+ from transformers.utils import logging
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from transformers.tokenization_utils_base import TextInput
35
+
36
logger = logging.get_logger(__name__)

# Name of the SentencePiece model file expected alongside the tokenizer config.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

# SentencePiece meta symbol that marks a word-initial (space-prefixed) piece.
SPIECE_UNDERLINE = "▁"

# Llama-2 chat formatting markers for instruction and system-prompt segments.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
53
+
54
+
55
+ class LlamaTokenizer(PreTrainedTokenizer):
56
+ """
57
+ Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
58
+ no padding token in the original model.
59
+
60
+ Args:
61
+ vocab_file (`str`):
62
+ Path to the vocabulary file.
63
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
64
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
65
+ token instead.
66
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
67
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
68
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
69
+ The end of sequence token.
70
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*):
71
+ A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
72
+ attention mechanisms or loss computation.
73
+ sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
74
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
75
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
76
+ to set:
77
+
78
+ - `enable_sampling`: Enable subword regularization.
79
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
80
+
81
+ - `nbest_size = {0,1}`: No sampling is performed.
82
+ - `nbest_size > 1`: samples from the nbest_size results.
83
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
84
+ using forward-filtering-and-backward-sampling algorithm.
85
+
86
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
87
+ BPE-dropout.
88
+
89
+ add_bos_token (`bool`, *optional*, defaults to `True`):
90
+ Whether or not to add an `bos_token` at the start of sequences.
91
+ add_eos_token (`bool`, *optional*, defaults to `False`):
92
+ Whether or not to add an `eos_token` at the end of sequences.
93
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
94
+ Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
95
+ extra spaces.
96
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
97
+ Whether or not the default system prompt for Llama should be used.
98
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
99
+ Whether or not to add spaces between special tokens.
100
+ legacy (`bool`, *optional*):
101
+ Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
102
+ and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
103
+ example:
104
+
105
+ - `legacy=True`:
106
+ ```python
107
+ >>> from transformers import T5Tokenizer
108
+
109
+ >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
110
+ >>> tokenizer.encode("Hello <extra_id_0>.")
111
+ [8774, 32099, 3, 5, 1]
112
+ ```
113
+ - `legacy=False`:
114
+ ```python
115
+ >>> from transformers import T5Tokenizer
116
+
117
+ >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
118
+ >>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
119
+ [8774, 32099, 5, 1]
120
+ ```
121
+ Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
122
+ add_prefix_space (`bool`, *optional*, defaults to `True`):
123
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
124
+ other word.
125
+
126
+ """
127
+
128
+ vocab_files_names = VOCAB_FILES_NAMES
129
+ model_input_names = ["input_ids", "attention_mask"]
130
+
131
    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        legacy=None,
        add_prefix_space=True,
        **kwargs,
    ):
        """Load the SentencePiece model and configure special-token / legacy behaviour.

        Note: the SentencePiece processor is created *before* the base-class constructor
        runs, because `PreTrainedTokenizer.__init__` may already need to tokenize.
        """
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Wrap string special tokens as non-normalized AddedToken so they match verbatim.
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        if legacy is None:
            # Caller did not choose: warn once and keep the historical behaviour.
            logger.warning_once(
                f"You are using the default legacy behaviour of the {self.__class__}. This is"
                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
                " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
                " means, and thoroughly read the reason why this was added as explained in"
                " https://github.com/huggingface/transformers/pull/24565"
            )
            legacy = True

        self.legacy = legacy
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt
        # `from_slow` is consumed here so it is not forwarded to the base constructor.
        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
        self.add_prefix_space = add_prefix_space

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            legacy=legacy,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
187
+
188
    @property
    def unk_token_length(self):
        """Number of pieces the unk token encodes to; used by `_tokenize` to strip the prefix trick."""
        return len(self.sp_model.encode(str(self.unk_token)))
191
+
192
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
    def get_spm_processor(self, from_slow=False):
        """Load the SentencePiece processor.

        In legacy (or `from_slow`) mode the model file is loaded as-is. Otherwise the
        serialized proto is rewritten to disable `add_dummy_prefix`, so prefix-space
        handling is done by this tokenizer class rather than inside SentencePiece.
        """
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        if self.legacy or from_slow:  # no dependency on protobuf
            tokenizer.Load(self.vocab_file)
            return tokenizer

        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
            model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
            model = model_pb2.ModelProto.FromString(sp_model)
            normalizer_spec = model_pb2.NormalizerSpec()
            normalizer_spec.add_dummy_prefix = False
            model.normalizer_spec.MergeFrom(normalizer_spec)
            sp_model = model.SerializeToString()
            tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer
209
+
210
+ def __getstate__(self):
211
+ state = self.__dict__.copy()
212
+ state["sp_model"] = None
213
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
214
+ return state
215
+
216
    def __setstate__(self, d):
        """Restore from a pickled snapshot by rebuilding the SentencePiece processor
        from the serialized proto saved in `__getstate__`."""
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
220
+
221
    @property
    def vocab_size(self):
        """Returns vocab size (the SentencePiece model's piece count; added tokens excluded)."""
        return self.sp_model.get_piece_size()
225
+
226
    def get_vocab(self):
        """Returns vocab as a dict mapping token string -> id, including tokens added after loading."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        # Added tokens override / extend the base SentencePiece vocabulary.
        vocab.update(self.added_tokens_encoder)
        return vocab
231
+
232
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
        """
        Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
        first token is special.
        """
        if self.legacy or len(text) == 0:
            return super().tokenize(text, **kwargs)

        # Treat any literal meta-symbol in the input as a plain space before re-prefixing.
        text = text.replace(SPIECE_UNDERLINE, " ")
        if self.add_prefix_space:
            text = SPIECE_UNDERLINE + text

        tokens = super().tokenize(text, **kwargs)

        # Drop the artificial lone prefix piece when it lands directly before a special token.
        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
            tokens = tokens[1:]
        return tokens
250
+
251
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
        """
        tokens = self.sp_model.encode(text, out_type=str)
        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
            return tokens

        # 1. Encode string + prefix ex: "<unk> Hey"
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
270
+
271
+ def _convert_token_to_id(self, token):
272
+ """Converts a token (str) in an id using the vocab."""
273
+ return self.sp_model.piece_to_id(token)
274
+
275
+ def _convert_id_to_token(self, index):
276
+ """Converts an index (integer) in a token (str) using the vocab."""
277
+ token = self.sp_model.IdToPiece(index)
278
+ return token
279
+
280
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string.

        Special tokens are emitted verbatim (never fed to SentencePiece); the runs of
        ordinary pieces between them are decoded in batches by the SentencePiece model.
        """
        # since we manually add the prefix space, we have to remove it when decoding
        if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
            tokens[0] = tokens[0][1:]

        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for i, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                # Legacy behaviour inserts a space before a special token mid-sequence.
                if not prev_is_special and i != 0 and self.legacy:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                # Restore the space lost to the stripped prefix piece right after a leading special token.
                if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
                    out_string += " "
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string
304
+
305
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
306
+ """
307
+ Save the vocabulary and special tokens file to a directory.
308
+
309
+ Args:
310
+ save_directory (`str`):
311
+ The directory in which to save the vocabulary.
312
+
313
+ Returns:
314
+ `Tuple(str)`: Paths to the files saved.
315
+ """
316
+ if not os.path.isdir(save_directory):
317
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
318
+ return
319
+ out_vocab_file = os.path.join(
320
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
321
+ )
322
+
323
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
324
+ copyfile(self.vocab_file, out_vocab_file)
325
+ elif not os.path.isfile(self.vocab_file):
326
+ with open(out_vocab_file, "wb") as fi:
327
+ content_spiece_model = self.sp_model.serialized_model_proto()
328
+ fi.write(content_spiece_model)
329
+
330
+ return (out_vocab_file,)
331
+
332
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
333
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
334
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
335
+
336
+ output = bos_token_id + token_ids_0 + eos_token_id
337
+
338
+ if token_ids_1 is not None:
339
+ output = output + bos_token_id + token_ids_1 + eos_token_id
340
+
341
+ return output
342
+
343
+ def get_special_tokens_mask(
344
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
345
+ ) -> List[int]:
346
+ """
347
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
348
+ special tokens using the tokenizer `prepare_for_model` method.
349
+
350
+ Args:
351
+ token_ids_0 (`List[int]`):
352
+ List of IDs.
353
+ token_ids_1 (`List[int]`, *optional*):
354
+ Optional second list of IDs for sequence pairs.
355
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
356
+ Whether or not the token list is already formatted with special tokens for the model.
357
+
358
+ Returns:
359
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
360
+ """
361
+ if already_has_special_tokens:
362
+ return super().get_special_tokens_mask(
363
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
364
+ )
365
+
366
+ bos_token_id = [1] if self.add_bos_token else []
367
+ eos_token_id = [1] if self.add_eos_token else []
368
+
369
+ if token_ids_1 is None:
370
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
371
+ return (
372
+ bos_token_id
373
+ + ([0] * len(token_ids_0))
374
+ + eos_token_id
375
+ + bos_token_id
376
+ + ([0] * len(token_ids_1))
377
+ + eos_token_id
378
+ )
379
+
380
+ def create_token_type_ids_from_sequences(
381
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
382
+ ) -> List[int]:
383
+ """
384
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
385
+ sequence pair mask has the following format:
386
+
387
+ ```
388
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
389
+ | first sequence | second sequence |
390
+ ```
391
+
392
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
393
+
394
+ Args:
395
+ token_ids_0 (`List[int]`):
396
+ List of ids.
397
+ token_ids_1 (`List[int]`, *optional*):
398
+ Optional second list of IDs for sequence pairs.
399
+
400
+ Returns:
401
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
402
+ """
403
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
404
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
405
+
406
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
407
+
408
+ if token_ids_1 is not None:
409
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
410
+
411
+ return output
412
+
413
+ @property
414
+ def default_chat_template(self):
415
+ """
416
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
417
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
418
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
419
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
420
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
421
+ to fine-tune a model with more flexible role ordering!
422
+
423
+ The output should look something like:
424
+
425
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
426
+ <bos>[INST] Prompt [/INST]
427
+
428
+ The reference for this chat template is [this code
429
+ snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
430
+ in the original repository.
431
+ """
432
+ logger.warning_once(
433
+ "\nNo chat template is defined for this tokenizer - using the default template "
434
+ f"for the {self.__class__.__name__} class. If the default is not appropriate for "
435
+ "your model, please set `tokenizer.chat_template` to an appropriate template. "
436
+ "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
437
+ )
438
+ template = (
439
+ "{% if messages[0]['role'] == 'system' %}"
440
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
441
+ "{% set system_message = messages[0]['content'] %}"
442
+ "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
443
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
444
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
445
+ "{% else %}"
446
+ "{% set loop_messages = messages %}"
447
+ "{% set system_message = false %}"
448
+ "{% endif %}"
449
+ "{% for message in loop_messages %}" # Loop over all non-system messages
450
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
451
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
452
+ "{% endif %}"
453
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
454
+ "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
455
+ "{% else %}"
456
+ "{% set content = message['content'] %}"
457
+ "{% endif %}"
458
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
459
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
460
+ "{% elif message['role'] == 'system' %}"
461
+ "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
462
+ "{% elif message['role'] == 'assistant' %}"
463
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
464
+ "{% endif %}"
465
+ "{% endfor %}"
466
+ )
467
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
468
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
469
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
470
+
471
+ return template
Unicorn/bunny/model/language_model/llama/tokenization_llama_fast.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ from shutil import copyfile
17
+ from typing import Optional, Tuple
18
+
19
+ from tokenizers import processors
20
+
21
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
22
+ from transformers.utils import is_sentencepiece_available, logging
23
+ from transformers.utils.versions import require_version
24
+
25
+
26
+ require_version("tokenizers>=0.13.3")
27
+
28
+ if is_sentencepiece_available():
29
+ from .tokenization_llama import LlamaTokenizer
30
+ else:
31
+ LlamaTokenizer = None
32
+
33
+ logger = logging.get_logger(__name__)
34
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
35
+
36
+ B_INST, E_INST = "[INST]", "[/INST]"
37
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
38
+
39
+ # fmt: off
40
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
41
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
42
+ that your responses are socially unbiased and positive in nature.
43
+
44
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
45
+ correct. If you don't know the answer to a question, please don't share false information."""
46
+ # fmt: on
47
+
48
+
49
class LlamaTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a Llama tokenizer backed by HuggingFace `tokenizers` (byte-level Byte-Pair-Encoding,
    notably with ByteFallback and no normalization).

    ```python
    >>> from transformers import LlamaTokenizerFast

    >>> tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
    >>> tokenizer.encode("Hello this is a test")
    [1, 15043, 445, 338, 263, 1243]
    ```

    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model,
    or call `tokenizer.update_post_processor()` afterwards to make sure the post-processing is correctly done
    (otherwise the values of the first token and final token of an encoded sequence will not be correct). For more
    details, checkout the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        tokenizer_file (`str`, *optional*):
            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
            extra spaces.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Llama should be used.
        add_prefix_space (`bool`, *optional*):
            Whether or not the tokenizer should automatically add a prefix space.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = LlamaTokenizer
    padding_side = "left"
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        add_bos_token=True,
        add_eos_token=False,
        use_default_system_prompt=False,
        add_prefix_space=None,
        **kwargs,
    ):
        # `add_prefix_space` is only honoured by the slow tokenizer, so force a slow->fast conversion.
        if add_prefix_space is not None:
            logger.warning_once(
                "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
            )
            kwargs["from_slow"] = True

        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )
        self._add_bos_token = add_bos_token
        self._add_eos_token = add_eos_token
        # Rebuild the post-processor so the BOS/EOS flags above actually take effect.
        self.update_post_processor()
        self.use_default_system_prompt = use_default_system_prompt
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # A slow tokenizer can only be materialised when the sentencepiece model file exists on disk.
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def update_post_processor(self):
        """
        Updates the underlying post processor with the current `bos_token` and `eos_token`.
        """
        bos, bos_token_id = self.bos_token, self.bos_token_id
        if bos is None and self.add_bos_token:
            raise ValueError("add_bos_token = True but bos_token = None")

        eos, eos_token_id = self.eos_token, self.eos_token_id
        if eos is None and self.add_eos_token:
            raise ValueError("add_eos_token = True but eos_token = None")

        # TemplateProcessing spec strings: "$A"/"$B" are the sequences, ":0"/":1" the type ids.
        single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
        pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"

        specials = []
        if self.add_bos_token:
            specials.append((bos, bos_token_id))
        if self.add_eos_token:
            specials.append((eos, eos_token_id))
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=single, pair=pair, special_tokens=specials
        )

    @property
    def add_eos_token(self):
        return self._add_eos_token

    @property
    def add_bos_token(self):
        return self._add_bos_token

    @add_eos_token.setter
    def add_eos_token(self, value):
        # Keep the post-processor in sync whenever the flag flips.
        self._add_eos_token = value
        self.update_post_processor()

    @add_bos_token.setter
    def add_bos_token(self, value):
        # Keep the post-processor in sync whenever the flag flips.
        self._add_bos_token = value
        self.update_post_processor()

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy the slow sentencepiece vocabulary file into *save_directory* and return its path."""
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        prefix = filename_prefix + "-" if filename_prefix else ""
        out_vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    @property
    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
    def default_chat_template(self):
        """
        LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
        Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
        user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
        rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
        results in an unusual token ordering when it is present. This template should definitely be changed if you wish
        to fine-tune a model with more flexible role ordering!

        The output should look something like:

        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST]

        The reference for this chat template is [this code
        snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
        in the original repository.
        """
        logger.warning_once(
            "\nNo chat template is defined for this tokenizer - using the default template "
            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
            "your model, please set `tokenizer.chat_template` to an appropriate template. "
            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
        )
        jinja_pieces = (
            "{% if messages[0]['role'] == 'system' %}",
            "{% set loop_messages = messages[1:] %}",
            "{% set system_message = messages[0]['content'] %}",
            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}",
            "{% set loop_messages = messages %}",
            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}",
            "{% else %}",
            "{% set loop_messages = messages %}",
            "{% set system_message = false %}",
            "{% endif %}",
            "{% for message in loop_messages %}",
            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}",
            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}",
            "{% endif %}",
            "{% if loop.index0 == 0 and system_message != false %}",
            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}",
            "{% else %}",
            "{% set content = message['content'] %}",
            "{% endif %}",
            "{% if message['role'] == 'user' %}",
            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}",
            "{% elif message['role'] == 'system' %}",
            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}",
            "{% elif message['role'] == 'assistant' %}",
            "{{ ' ' + content.strip() + ' ' + eos_token }}",
            "{% endif %}",
            "{% endfor %}",
        )
        template = "".join(jinja_pieces)
        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
        return template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)

    # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap one or two id sequences with BOS/EOS ids according to the tokenizer's flags."""
        prefix = [self.bos_token_id] if self.add_bos_token else []
        suffix = [self.eos_token_id] if self.add_eos_token else []

        ids = prefix + token_ids_0 + suffix
        if token_ids_1 is not None:
            ids += prefix + token_ids_1 + suffix
        return ids
Unicorn/bunny/model/language_model/minicpm/__pycache__/configuration_minicpm.cpython-310.pyc ADDED
Binary file (8.05 kB). View file
 
Unicorn/bunny/model/language_model/minicpm/__pycache__/modeling_minicpm.cpython-310.pyc ADDED
Binary file (45 kB). View file
 
Unicorn/bunny/model/language_model/minicpm/configuration_minicpm.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ MiniCPM model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29
+
30
+
31
class MiniCPMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an
    MiniCPM model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the MiniCPM-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`MiniCPMModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value
            is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
            is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. This is an experimental feature, subject to
            breaking API changes in future versions.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        scale_emb (`float`, *optional*, defaults to 1):
            Multiplier applied to the embedding output (MiniCPM-specific; see `scale_depth`/`dim_model_base`).
        dim_model_base (`int`, *optional*, defaults to 1):
            Base model width used to scale the logits (MiniCPM-specific).
        scale_depth (`float`, *optional*, defaults to 1):
            Depth-dependent residual scaling factor (MiniCPM-specific).

    ```python
    >>> from transformers import MiniCPMModel, MiniCPMConfig

    >>> # Initializing a MiniCPM minicpm-7b style configuration
    >>> configuration = MiniCPMConfig()

    >>> # Initializing a model from the minicpm-7b style configuration
    >>> model = MiniCPMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "minicpm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        scale_emb=1,
        dim_model_base=1,
        scale_depth=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: GQA head count defaults to full MHA
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.scale_emb = scale_emb
        self.dim_model_base = dim_model_base
        self.scale_depth = scale_depth

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        # Best-effort probe: prefer FlashAttention-2 whenever the `flash_attn` package is usable.
        # Fix: was a bare `except:`, which would also swallow SystemExit/KeyboardInterrupt.
        # Kept broader than ImportError on purpose — importing flash_attn can fail with e.g. an
        # OSError on machines without the matching CUDA libraries, and that must stay non-fatal.
        try:
            import flash_attn  # noqa: F401  (imported only to test availability)
            self._attn_implementation = "flash_attention_2"
        except Exception:
            pass

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: if `rope_scaling` is not a two-key dict with a `type` in
                {"linear", "dynamic"} and a float `factor` > 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            # Fix: removed duplicated "with with" in the original error message.
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
Unicorn/bunny/model/language_model/minicpm/modeling_minicpm.py ADDED
@@ -0,0 +1,1456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch MiniCPM model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union, Dict
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.modeling_attn_mask_utils import (
34
+ AttentionMaskConverter,
35
+ _prepare_4d_attention_mask,
36
+ _prepare_4d_causal_attention_mask,
37
+ _prepare_4d_causal_attention_mask_for_sdpa,
38
+ )
39
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
40
+ from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
42
+ from transformers.utils import (
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ is_flash_attn_2_available,
46
+ is_flash_attn_greater_or_equal_2_10,
47
+ logging,
48
+ replace_return_docstrings,
49
+ )
50
+ from transformers.utils.import_utils import is_torch_fx_available
51
+ from .configuration_minicpm import MiniCPMConfig
52
+ import re
53
+
54
+ try:
55
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
56
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
+ except:
58
+ pass
59
+
60
+
61
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
62
+ # It means that the function will not be traced through and simply appear as a node in the graph.
63
+ if is_torch_fx_available():
64
+ if not is_torch_greater_or_equal_than_1_13:
65
+ import torch.fx
66
+
67
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
68
+
69
+
70
+ logger = logging.get_logger(__name__)
71
+
72
+ _CONFIG_FOR_DOC = "MiniCPMConfig"
73
+
74
+
75
+ def _get_unpad_data(attention_mask):
76
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
77
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
78
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
79
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
80
+ return (
81
+ indices,
82
+ cu_seqlens,
83
+ max_seqlen_in_batch,
84
+ )
85
+
86
+
87
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Deprecated shim kept for backward compatibility.

    Emits a deprecation warning, then delegates to
    `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask`.
    """
    deprecation_msg = (
        "Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
    )
    warnings.warn(deprecation_msg)
    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
92
+
93
+
94
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """Deprecated shim kept for backward compatibility.

    Emits a deprecation warning, then delegates to
    `AttentionMaskConverter._make_causal_mask`.
    """
    deprecation_msg = (
        "Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask"
    )
    warnings.warn(deprecation_msg)
    return AttentionMaskConverter._make_causal_mask(
        input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
    )
103
+
104
+ # @torch.jit.script # type: ignore
105
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
    """RMS-normalize `hidden` over its last dimension and scale by `weight`.

    The mean-square statistic is computed in float32 for numerical stability
    and the result is cast back to the input dtype before scaling.
    """
    input_dtype = hidden.dtype
    mean_square = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    normalized = (hidden * torch.rsqrt(mean_square + eps)).to(input_dtype)
    return normalized * weight
110
+
111
+
112
class MiniCPMRMSNorm(nn.Module):
    """RMS normalization layer (functionally equivalent to T5LayerNorm).

    Thin `nn.Module` wrapper around the functional `rms_layernorm`, holding
    the learnable per-channel scale and the variance epsilon.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Delegate the actual math to the shared functional implementation.
        return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
123
+
124
+
125
+ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
126
+
127
+
128
class MiniCPMRotaryEmbedding(nn.Module):
    """Precomputes and caches cos/sin tables for rotary position embeddings.

    The tables cover `max_position_embeddings` positions up front and grow
    lazily in `forward` when a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build the cache eagerly so `torch.jit.trace` works.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables covering positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angle table so it spans the full head dimension
        # (different permutation from the paper, same end result).
        table = torch.cat((angles, angles), dim=-1)

        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]; only its
        # dtype/device are consulted here.
        if seq_len > self.max_seq_len_cached:
            # Grow the cache on demand for longer sequences.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return (cos, sin)
163
+
164
+
165
class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
    """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which builds the cos/sin cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Positions are compressed by the scaling factor before the angles
        # are computed, stretching the effective context window.
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angle table to span the full head dimension
        # (different permutation from the paper, same end result).
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
182
+
183
+
184
class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
    """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which builds the cos/sin cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware base rescaling: enlarge the rotary base when the
            # requested length exceeds the trained context window, then
            # recompute the inverse frequencies from the adjusted base.
            adjusted_base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (adjusted_base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angle table to span the full head dimension
        # (different permutation from the paper, same end result).
        table = torch.cat((angles, angles), dim=-1)

        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
209
+
210
+
211
def rotate_half(x):
    """Rotate the two halves of the last dimension: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
216
+
217
+
218
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which to unsqueeze `cos[position_ids]` and `sin[position_ids]` so they broadcast over
            the heads axis of q and k (1 for [bs, heads, seq, dim] layouts, 2 for [bs, seq, heads, dim]).
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    original_dtype = k.dtype
    # Gather per-position tables and insert a broadcast dim for the heads axis.
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
    # Perform the rotation in float32 for numerical stability, then cast back.
    q32 = q.to(dtype=torch.float32, device=q.device)
    k32 = k.to(dtype=torch.float32, device=k.device)
    rotated_q = (q32 * cos_sel) + (rotate_half(q32) * sin_sel)
    rotated_k = (k32 * cos_sel) + (rotate_half(k32) * sin_sel)
    return rotated_q.to(dtype=original_dtype), rotated_k.to(dtype=original_dtype)
251
+
252
class MiniCPMMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x)).

    When `config.pretraining_tp > 1`, the projections are evaluated shard by
    shard over split weight matrices to emulate tensor parallelism.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        tp = self.config.pretraining_tp
        if tp <= 1:
            # Common path: fused gated MLP.
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        # Tensor-parallel emulation: split the weights along the intermediate
        # dimension, apply each shard separately, then recombine.
        shard = self.intermediate_size // tp
        gate_shards = self.gate_proj.weight.split(shard, dim=0)
        up_shards = self.up_proj.weight.split(shard, dim=0)
        down_shards = self.down_proj.weight.split(shard, dim=1)

        gate_out = torch.cat([F.linear(x, gate_shards[i]) for i in range(tp)], dim=-1)
        up_out = torch.cat([F.linear(x, up_shards[i]) for i in range(tp)], dim=-1)

        hidden_shards = (self.act_fn(gate_out) * up_out).split(shard, dim=2)
        return sum(F.linear(hidden_shards[i], down_shards[i]) for i in range(tp))
284
+
285
+
286
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        # Nothing to repeat; return the input unchanged (same object).
        return hidden_states
    bsz, n_kv_heads, seq_len, head_dim = hidden_states.shape
    # Insert a repeat axis, broadcast it, then fold it into the heads axis.
    expanded = hidden_states[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, n_kv_heads * n_rep, seq_len, head_dim)
296
+
297
+
298
+
299
class MiniCPMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Eager (matmul + softmax) attention implementation. Supports grouped-query
    # attention (num_key_value_heads <= num_heads), a KV `Cache`, and an
    # optional tensor-parallel emulation path gated by `config.pretraining_tp`.

    def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # A missing layer_idx breaks per-layer KV-cache bookkeeping below.
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each KV head (1 = standard MHA).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # K/V projections are sized by num_key_value_heads (GQA).
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Build the rotary-embedding module from `config.rope_scaling`.

        None -> plain RoPE; {"type": "linear"|"dynamic", "factor": f} selects
        the corresponding scaled variant.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = MiniCPMRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape a flat projection to [bsz, num_heads, seq_len, head_dim].
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager attention forward pass.

        Returns (attn_output, attn_weights or None, past_key_value); weights
        are returned only when `output_attentions` is True.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Tensor-parallel emulation: apply Q/K/V shard by shard over
            # split weight matrices, then concatenate the outputs.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # [bsz, seq, heads*dim] -> [bsz, heads, seq, dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            # Total length including previously cached positions.
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # cos/sin tables requested in fp32; only value_states' dtype/device matter.
        cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand KV heads to match the number of query heads (GQA).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores: [bsz, heads, q_len, kv_seq_len].
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask: disallowed positions carry large negative values.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # [bsz, heads, q_len, dim] -> [bsz, q_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            # Sharded output projection mirroring the sharded Q/K/V above.
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
468
+
469
+
470
class MiniCPMFlashAttention2(MiniCPMAttention):
    """
    MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Flash-attention forward pass; attn_weights are always None."""
        # MiniCPMFlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            # Account for previously cached positions when sizing RoPE tables.
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Dropout is only applied at training time.
        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (MiniCPMRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            # Varlen path: strip padding, run the varlen kernel, re-pad.
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            # Dense path: no padding anywhere, use the plain kernel.
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Remove padding from Q/K/V and build the varlen-kernel metadata.

        Returns the unpadded layers plus query indices, (cu_seqlens_q,
        cu_seqlens_k) and (max_seqlen_q, max_seqlen_k).
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            # Prefill: queries share the key layout and metadata.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
667
+
668
+
669
class MiniCPMSdpaAttention(MiniCPMAttention):
    """
    MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `MiniCPMAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from MiniCPMAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # SDPA cannot return attention weights, so fall back to the eager
        # (manual) implementation when the caller requests them.
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "MiniCPMModel is using MiniCPMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to per-head query/key/value: (bsz, heads, q_len, head_dim)
        # after the transpose.
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key length includes any cached past tokens for this layer.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped KV heads so every query head has a matching KV head.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        # Back to (bsz, q_len, hidden_size) for the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # Second element is None: SDPA never materializes attention weights.
        return attn_output, None, past_key_value
754
+
755
+
756
# Maps `config._attn_implementation` to the attention backend instantiated by
# each `MiniCPMDecoderLayer`.
MINICPM_ATTENTION_CLASSES = {
    "eager": MiniCPMAttention,
    "flash_attention_2": MiniCPMFlashAttention2,
    "sdpa": MiniCPMSdpaAttention,
}
761
+
762
+
763
class MiniCPMDecoderLayer(nn.Module):
    """One MiniCPM transformer block: pre-norm self-attention followed by a
    pre-norm MLP, each merged back into the residual stream with a depth
    scaling of `scale_depth / sqrt(num_hidden_layers)`."""

    def __init__(self, config: MiniCPMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Attention backend (eager / flash-attn-2 / sdpa) is chosen from the config.
        self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = MiniCPMMLP(config)
        self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.scale_depth = config.scale_depth
        self.num_hidden_layers = config.num_hidden_layers

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): 2D `(batch_size, sequence_length)`
                mask for flash attention, otherwise a 4D
                `(batch_size, 1, query_sequence_length, key_sequence_length)` mask.
            position_ids (`torch.LongTensor`, *optional*): position indices for rotary embeddings.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value projections.
            output_attentions (`bool`, *optional*): also return the attention weights.
            use_cache (`bool`, *optional*): also return the updated key/value cache.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual_scale = self.scale_depth / math.sqrt(self.num_hidden_layers)

        # --- self-attention sub-block (pre-norm residual) ---
        shortcut = hidden_states
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = shortcut + attn_out * residual_scale

        # --- feed-forward sub-block (pre-norm residual) ---
        shortcut = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = shortcut + mlp_out * residual_scale

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        return outputs
836
+
837
+
838
# Class-level docstring shared by all MiniCPM model heads, injected via
# `@add_start_docstrings`. Runtime string data — do not edit casually.
MINICPM_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MiniCPMConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
853
+
854
+
855
@add_start_docstrings(
    "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
    MINICPM_START_DOCSTRING,
)
class MiniCPMPreTrainedModel(PreTrainedModel):
    """Common base for all MiniCPM models: wires up the config class, weight
    initialization, and the attention/cache capability flags consumed by the
    Transformers loading machinery."""

    config_class = MiniCPMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MiniCPMDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding embedding at exactly zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
879
+
880
+
881
# Shared `forward()` argument documentation, injected via
# `@add_start_docstrings_to_model_forward`. Runtime string data — do not edit casually.
MINICPM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
949
+
950
+
951
@add_start_docstrings(
    "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
    MINICPM_START_DOCSTRING,
)
class MiniCPMModel(MiniCPMPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]

    Args:
        config: MiniCPMConfig
    """

    def __init__(self, config: MiniCPMConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # The attention backend determines which attention-mask format the
        # layers expect (2D for flash-attn, 4D otherwise) — see forward().
        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the token-embedding table (see MINICPM docstrings above).
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Fall back to config defaults for any unspecified output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Normalize legacy tuple caches to a `Cache` object; remember the
        # original format so the output matches the input format.
        past_key_values_length = 0
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Positions continue from the cached prefix.
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            # MiniCPM scales token embeddings by `scale_emb`.
            inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb

        # Build the attention mask in the format the chosen backend expects.
        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            # Cache position in the layer-output tuple depends on whether
            # attention weights were also returned.
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1118
+
1119
+
1120
class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
    """MiniCPM decoder with a language-modeling head (`lm_head`) on top."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MiniCPMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniCPMForCausalLM

        >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Tensor-parallel pretraining: compute the logits shard-by-shard.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # MiniCPM scales hidden states by dim_model_base/hidden_size before the LM head.
            logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim `input_ids`/`attention_mask`/`position_ids` to the tokens not yet
        covered by the KV cache before each generation step."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: length is the key tensor's seq dimension.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            else:
                remove_prefix_length = input_ids.shape[1] - 1
                input_ids = input_ids[:, remove_prefix_length:]
            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                    max_cache_length is not None
                    and attention_mask is not None
                    and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder a legacy tuple cache along the batch axis for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
             max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
             **kwargs):
        """Run one chat turn: append `query` to `history`, generate a reply, and
        return `(response, updated_history)`.

        Raises whatever `generate`/the tokenizer raises; `history` is mutated
        in place (the query and the assistant reply are appended).
        """
        if history is None:
            history = []
        # FIX: the previous code branched on `logits_processor` but built the
        # exact same kwargs dict in both branches (dead conditional). A single
        # dict is equivalent — `generate(logits_processor=None)` is the default.
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}

        history.append({"role": role, "content": query})
        history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
        inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
        outputs = self.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens and the trailing EOS before decoding.
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        # Cut the reply at the first turn marker, if the model emitted one.
        pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
        matches = pattern.findall(response)
        if len(matches) > 0:
            response = matches[0]
        history.append({"role": "assistant", "content": response})
        return response, history
1335
+
1336
+
1337
@add_start_docstrings(
    """
    The MiniCPM Model transformer with a sequence classification head on top (linear layer).

    [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MINICPM_START_DOCSTRING,
)
class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
    def __init__(self, config):
        """Build the MiniCPM backbone plus a bias-free linear classification head."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MiniCPMModel(config)
        # Projects the hidden state of the pooled token position to `num_labels` logits.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the decoder backbone; logits are produced for every position and pooled below.
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Locate the last "real" token of each row; its logits represent the sequence.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of first pad token minus one == last non-pad position.
                # NOTE(review): this assumes right padding; if a row contains no pad
                # token, argmax is 0 and the index becomes -1 (the final position) —
                # confirm callers never left-pad.
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                    logits.device
                )
            else:
                # Padding cannot be detected from `inputs_embeds`; fall back to the last position.
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from the label dtype / label count and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            # Legacy tuple output: prepend the loss when it exists.
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
Unicorn/bunny/model/language_model/phi/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Microsoft and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
from typing import TYPE_CHECKING

from transformers.utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_sentencepiece_available,
    is_tokenizers_available,
    is_torch_available,
)


# Mapping of submodule name -> exported names, consumed by `_LazyModule` so that the
# heavy modeling code is only imported when one of its attributes is first accessed.
_import_structure = {
    "configuration_phi": ["PHI_PRETRAINED_CONFIG_ARCHIVE_MAP", "PhiConfig"],
}

# The modeling module is only registered when torch is installed.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_phi"] = [
        "PHI_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PhiPreTrainedModel",
        "PhiModel",
        "PhiForCausalLM",
        "PhiForSequenceClassification",
        "PhiForTokenClassification",
    ]


if TYPE_CHECKING:
    # Static type checkers see the real imports; at runtime the lazy proxy below is used instead.
    from .configuration_phi import PHI_PRETRAINED_CONFIG_ARCHIVE_MAP, PhiConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_phi import (
            PHI_PRETRAINED_MODEL_ARCHIVE_LIST,
            PhiForCausalLM,
            PhiForSequenceClassification,
            PhiForTokenClassification,
            PhiModel,
            PhiPreTrainedModel,
        )


else:
    import sys

    # Replace this module object with a lazy proxy that imports submodules on demand.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Unicorn/bunny/model/language_model/phi/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.05 kB). View file
 
Unicorn/bunny/model/language_model/phi/__pycache__/configuration_phi.cpython-310.pyc ADDED
Binary file (8.02 kB). View file
 
Unicorn/bunny/model/language_model/phi/__pycache__/modeling_phi.cpython-310.pyc ADDED
Binary file (39.8 kB). View file
 
Unicorn/bunny/model/language_model/phi/configuration_phi.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
logger = logging.get_logger(__name__)

# Canonical Hugging Face Hub checkpoints whose configs this configuration class loads.
PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json",
    "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json",
    "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
}
30
+
31
+
32
class PhiConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Phi
    [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 51200):
            Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`PhiModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            Dropout probability for mlp outputs.
        embd_pdrop (`int`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio after computing the attention scores.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
            tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
            is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
            is an experimental feature, subject to breaking API changes in future versions.
        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
            Percentage of the query and keys which will have rotary embedding.
        qk_layernorm (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize the Queries and Keys after projecting the hidden states.
        bos_token_id (`int`, *optional*, defaults to 1):
            Denotes beginning of sequences token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            Denotes end of sequences token id.

    Example:

    ```python
    >>> from transformers import PhiModel, PhiConfig

    >>> # Initializing a Phi-1 style configuration
    >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")

    >>> # Initializing a model from the configuration
    >>> model = PhiModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "phi"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=51200,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=None,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attention_dropout=0.0,
        hidden_act="gelu_new",
        max_position_embeddings=2048,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.5,
        qk_layernorm=False,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # GQA: default to MHA when the number of KV heads is not given.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.qk_layernorm = qk_layernorm
        # Validate before the base-class __init__ so a malformed config fails fast.
        self._rope_scaling_validation()

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: if `rope_scaling` is not a two-field dict, its `type` is not
                one of `linear`/`dynamic`, or its `factor` is not a float > 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                # Fixed duplicated word ("with with") in the original error message.
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        # NOTE(review): an integer `factor` (e.g. 2) is rejected by this isinstance
        # check; JSON configs are expected to use floats — confirm before relaxing.
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
Unicorn/bunny/model/language_model/phi/modeling_phi.py ADDED
@@ -0,0 +1,1374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi model."""
17
+
18
+
19
+ import math
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ SequenceClassifierOutputWithPast,
35
+ TokenClassifierOutput,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ is_flash_attn_2_available,
43
+ is_flash_attn_greater_or_equal_2_10,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_phi import PhiConfig
48
+
49
+
50
# flash-attn is optional; only import its kernels when the package is installed.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


logger = logging.get_logger(__name__)

# Checkpoint/config names consumed by the docstring decorators on the model classes below.
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig"

PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/phi-1",
    "microsoft/phi-1_5",
    "microsoft/phi-2",
    # See all Phi models at https://huggingface.co/models?filter=phi
]
66
+
67
+
68
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
69
+ def _get_unpad_data(attention_mask):
70
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
71
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
72
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
73
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
74
+ return (
75
+ indices,
76
+ cu_seqlens,
77
+ max_seqlen_in_batch,
78
+ )
79
+
80
+
81
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
class PhiRotaryEmbedding(nn.Module):
    """Precomputes and caches the cos/sin tables used for rotary position embeddings."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Grow the cached tables lazily when a longer sequence shows up.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype)
116
+
117
+
118
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
    """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which triggers _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Positions are compressed by `scaling_factor` so longer contexts map into the trained range.
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
136
+
137
+
138
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
    """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before super().__init__ because the parent constructor calls _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # Beyond the trained context length, rescale the RoPE base (NTK-aware scaling)
        # and re-register `inv_freq`; within the trained length the inherited one is reused.
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
163
+
164
+
165
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a) pair-wise."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
171
+
172
+
173
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Position indices of the tokens in q/k; may be offset when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which cos[position_ids] / sin[position_ids] are unsqueezed so they
            broadcast against q and k: use 1 for [bs, heads, seq, dim] layouts and 2 for
            [bs, seq, heads, dim] layouts.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Gather per-token angle rows and insert a broadcast axis for the heads dimension.
    cos_per_token = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_per_token = sin[position_ids].unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_per_token + rotate_half(q) * sin_per_token
    rotated_k = k * cos_per_token + rotate_half(k) * sin_per_token
    return rotated_q, rotated_k
200
+
201
+
202
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
class PhiMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand, apply the configured non-linearity, then project back to hidden_size.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
216
+
217
+
218
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # MHA case: every query head already has its own KV head.
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
229
+
230
+
231
class PhiAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (pure-PyTorch) implementation with:
      * grouped-query attention (``num_key_value_heads`` may be smaller than
        ``num_heads``; KV heads are expanded via ``repeat_kv``),
      * partial rotary position embeddings (only the first
        ``partial_rotary_factor * head_dim`` channels are rotated),
      * an fp32 upcast of the Q·K^T product (required by Phi-2 to avoid overflow).
    """

    def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        # layer_idx identifies this layer inside a Cache; without it, cached decoding fails.
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each KV head (1 == standard MHA).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Phi uses biased projections and calls the output projection `dense`.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)

        self.qk_layernorm = config.qk_layernorm
        if self.qk_layernorm:
            # NOTE(review): these LayerNorms are sized head_dim (hidden_size // num_heads),
            # but forward() applies them BEFORE the heads are split, i.e. over a last dim of
            # hidden_size — confirm against a config that actually sets qk_layernorm=True.
            self.q_layernorm = nn.LayerNorm(
                config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
            )
            self.k_layernorm = nn.LayerNorm(
                config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
            )

        self._init_rope()

    def _init_rope(self):
        """Build the rotary embedding, honoring optional `rope_scaling` config.

        The rotary dimension is `partial_rotary_factor * head_dim` — only that
        prefix of each head is rotated in forward().
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = PhiRotaryEmbedding(
                int(self.partial_rotary_factor * self.head_dim),
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = PhiLinearScalingRotaryEmbedding(
                    int(self.partial_rotary_factor * self.head_dim),
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
                    int(self.partial_rotary_factor * self.head_dim),
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run eager attention.

        Returns ``(attn_output, attn_weights_or_None, past_key_value)``;
        ``attn_weights`` is only populated when ``output_attentions`` is True.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        if self.qk_layernorm:
            # Applied on the flat (bsz, q_len, hidden_size) projections, before the
            # per-head reshape below — see the NOTE(review) in __init__.
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        # (bsz, q_len, n_heads * head_dim) -> (bsz, n_heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            # Total sequence length including previously cached tokens for this layer.
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding: rotate only the first `rotary_emb.dim` channels,
        # pass the remainder through unchanged.
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            # Store the post-RoPE keys; partial_rotation_size lets the cache know how
            # much of the head was rotated (needed by sin/cos-based cache strategies).
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand KV heads so each query head has a matching key/value head (GQA).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
        attn_weights = torch.matmul(
            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask: masked positions carry large negative values.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz, n_heads, q_len, head_dim) -> (bsz, q_len, hidden_size).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.dense(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
401
+
402
+
403
class PhiFlashAttention2(PhiAttention):
    """
    Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Flash-attention forward pass.

        Mirrors `PhiAttention.forward` up to the cache update, then dispatches
        to `_flash_attention_forward`. Always returns `attn_weights=None`
        because flash attention never materializes the attention matrix.
        """
        # PhiFlashAttention2 attention does not support output_attentions

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        if self.qk_layernorm:
            # Same pre-head-split application as the eager path; see PhiAttention.
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            # Include tokens already cached for this layer in the RoPE length.
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        # NOTE: unlike the eager path, no repeat_kv here — flash-attn handles
        # num_key_value_heads < num_heads natively (presumably; verify against flash-attn docs).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Dropout only during training; flash-attn takes the probability directly.
        attn_dropout = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32.

        if query_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.dense(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # Variable-length kernel over the packed (unpadded) sequences.
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Scatter the packed outputs back into padded (batch, seq) layout.
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Strip padding tokens so flash_attn_varlen_func sees only real tokens.

        Returns the packed q/k/v layers plus the index/cumulative-length
        metadata flash-attn needs to reconstruct per-sequence boundaries.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            # Prefill: queries and keys share the same token positions.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
618
+
619
+
620
# Maps `config._attn_implementation` to the attention module PhiDecoderLayer instantiates.
PHI_ATTENTION_CLASSES = {
    "eager": PhiAttention,
    "flash_attention_2": PhiFlashAttention2,
}
624
+
625
+
626
class PhiDecoderLayer(nn.Module):
    """One Phi transformer block.

    Phi uses a *parallel* block layout: a single pre-LayerNorm feeds both the
    self-attention branch and the MLP branch, and the two branch outputs are
    summed with the residual (rather than applied sequentially).
    """

    def __init__(self, config: PhiConfig, layer_idx: int):
        super().__init__()
        self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        self.mlp = PhiMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder block.

        Args:
            hidden_states: `(batch, seq_len, embed_dim)` input.
            attention_mask: additive mask `(batch, 1, tgt_len, src_len)`;
                masked positions carry large negative values.
            position_ids: position indices in `[0, config.n_positions - 1]`.
            output_attentions: also return the attention weights.
            use_cache: also return the updated key/value cache.
            past_key_value: cached key/value projection states.

        Returns:
            Tuple of `(hidden_states[, attn_weights][, present_key_value])`,
            the optional members included per the corresponding flags.
        """
        residual = hidden_states

        # One shared pre-norm feeds BOTH branches (parallel block design).
        normed_input = self.input_layernorm(hidden_states)

        attn_branch, attn_weights, present_key_value = self.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        # Dropout order (attention first, then MLP) is kept for RNG parity.
        attn_branch = self.resid_dropout(attn_branch)
        mlp_branch = self.resid_dropout(self.mlp(normed_input))

        outputs = (attn_branch + mlp_branch + residual,)
        if output_attentions:
            outputs = outputs + (attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
687
+
688
+
689
# Shared class-level docstring, injected into model classes via `add_start_docstrings`.
PHI_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PhiConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
704
+
705
+
706
@add_start_docstrings(
    "The bare Phi Model outputting raw hidden-states without any specific head on top.",
    PHI_START_DOCSTRING,
)
class PhiPreTrainedModel(PreTrainedModel):
    """Abstract base for all Phi models: ties `PhiConfig` to the weight
    initialization scheme and declares framework capability flags
    (flash-attention 2 support, `Cache` support, checkpointing, sharding)."""

    config_class = PhiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PhiDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        zero Linear biases and the Embedding padding row."""
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
729
+
730
+
731
# Forward-pass argument documentation, injected via `add_start_docstrings_to_model_forward`.
PHI_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
799
+
800
+
801
@add_start_docstrings(
    "The bare Phi Model outputting raw hidden-states without any specific head on top.",
    PHI_START_DOCSTRING,
)
class PhiModel(PhiPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]

    Args:
        config: PhiConfig
    """

    def __init__(self, config: PhiConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        self.layers = nn.ModuleList(
            [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Phi applies a single final LayerNorm after the last decoder layer.
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding table (used e.g. when resizing the vocab)."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the input, run all decoder layers, and apply the final LayerNorm.

        Returns a `BaseModelOutputWithPast` (or a plain tuple when
        `return_dict=False`) with the last hidden state, the updated cache,
        and optionally all hidden states / attention weights.
        """
        # Fall back to config defaults for any unspecified output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds (exactly one of the two must be given)
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache:
            # Accept either a `Cache` instance or the legacy tuple-of-tuples format;
            # normalize to `DynamicCache` internally and convert back on return.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            # Continue position numbering after any cached tokens.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        inputs_embeds = self.embed_dropout(inputs_embeds)

        # Attention mask.
        if self._use_flash_attention_2:
            # 2d mask is passed through the layers; dropped entirely when there is no padding.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d additive causal mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # FIX: positional args must follow PhiDecoderLayer.forward's parameter
                # order (hidden_states, attention_mask, position_ids, output_attentions,
                # use_cache, past_key_value). The previous order passed the cache object
                # into `output_attentions` and `output_attentions` into `use_cache`.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    output_attentions,
                    use_cache,
                    past_key_values,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The same Cache object is threaded through every layer; keep the latest.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.final_layernorm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
959
+
960
+
961
class PhiForCausalLM(PhiPreTrainedModel):
    """Phi decoder transformer with a language-modeling head on top.

    Wraps :class:`PhiModel` and projects its final hidden states to
    vocabulary logits through ``lm_head``. ``_tied_weights_keys`` lets the
    head weights be tied to the input embeddings when the config requests it.
    """

    _tied_weights_keys = ["lm_head.weight"]

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
    def __init__(self, config):
        super().__init__(config)
        self.model = PhiModel(config)
        self.vocab_size = config.vocab_size
        # NOTE: unlike Llama, the Phi LM head carries a bias term (bias=True).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
    def get_input_embeddings(self):
        """Return the token embedding table of the wrapped decoder."""
        return self.model.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
    def set_input_embeddings(self, value):
        """Replace the token embedding table of the wrapped decoder."""
        self.model.embed_tokens = value

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
    def get_output_embeddings(self):
        """Return the LM head (output projection)."""
        return self.lm_head

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (output projection)."""
        self.lm_head = new_embeddings

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
    def set_decoder(self, decoder):
        """Replace the underlying decoder stack."""
        self.model = decoder

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
    def get_decoder(self):
        """Return the underlying decoder stack."""
        return self.model

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PhiForCausalLM

        >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

        >>> prompt = "This is an example script ."
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
        ```"""

        # Fall back to the config defaults for any output flag not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Upcast to fp32 so the softmax/cross-entropy is computed in full precision.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            # Tuple output: (loss?, logits, *decoder_extras).
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim ``input_ids``/``position_ids`` to the tokens not yet in the KV cache.

        Called by ``generate()`` before every decoding step; returns the kwargs
        dict consumed by :meth:`forward`.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # NOTE(review): `seen_tokens` / `get_max_length` are deprecated in
                # newer `transformers` Cache APIs — confirm against the pinned version.
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: length is the key tensor's sequence dim.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            else:
                remove_prefix_length = input_ids.shape[1] - 1
                input_ids = input_ids[:, remove_prefix_length:]
            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
            # generation with static cache
            seen_tokens = past_key_value.get_seq_length()
            input_ids = input_ids[:, seen_tokens:]
            position_ids = position_ids[:, seen_tokens:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder the legacy tuple KV cache along the batch dim for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
1161
+
1162
+
1163
@add_start_docstrings(
    """
    The PhiModel with a sequence classification head on top (linear layer).

    [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    PHI_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
class PhiForSequenceClassification(PhiPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = PhiModel(config)
        # Classification head projects the pooled hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table of the wrapped decoder."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding table of the wrapped decoder."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = model_outputs[0]
        # Per-token class scores; pooled to one score vector per sequence below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token we cannot locate each row's last real token,
        # so batched classification is only defined for batch_size == 1.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                # With inputs_embeds there are no token ids to scan; take the last position.
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels and the label dtype,
            # then cache it on the config (mirrors the upstream HF behavior).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + model_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
1285
+
1286
+
1287
@add_start_docstrings(
    """
    PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    PHI_START_DOCSTRING,
)
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
class PhiForTokenClassification(PhiPreTrainedModel):
    def __init__(self, config: PhiConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = PhiModel(config)
        # Dropout rate: prefer `classifier_dropout`, then `hidden_dropout`,
        # falling back to 0.1 when neither is set on the config.
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Per-token classification head.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = model_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            # [2:] skips the decoder's past_key_values slot in the tuple output
            # (mirrors the upstream MptForTokenClassification this was copied from).
            output = (logits,) + model_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
Unicorn/bunny/model/language_model/phi3/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Lazy-import package init for the vendored Phi-3 model (HF `_LazyModule` pattern).

Submodules are only imported on first attribute access; torch-dependent
symbols are registered only when torch is available.
"""

from typing import TYPE_CHECKING

from transformers.utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_sentencepiece_available,
    is_tokenizers_available,
    is_torch_available,
)


# Maps submodule name -> public names; consumed by _LazyModule below.
_import_structure = {
    "configuration_phi3": ["PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP", "Phi3Config"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Without torch, only the configuration symbols are exposed.
    pass
else:
    _import_structure["modeling_phi3"] = [
        "PHI3_PRETRAINED_MODEL_ARCHIVE_LIST",
        "Phi3PreTrainedModel",
        "Phi3Model",
        "Phi3ForCausalLM",
        "Phi3ForSequenceClassification",
        "Phi3ForTokenClassification",
    ]


if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    from .configuration_phi3 import PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP, Phi3Config

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_phi3 import (
            PHI3_PRETRAINED_MODEL_ARCHIVE_LIST,
            Phi3ForCausalLM,
            Phi3ForSequenceClassification,
            Phi3ForTokenClassification,
            Phi3Model,
            Phi3PreTrainedModel,
        )


else:
    import sys

    # Replace this module object with a _LazyModule proxy that imports
    # submodules on demand, per _import_structure.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Unicorn/bunny/model/language_model/phi3/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
Unicorn/bunny/model/language_model/phi3/__pycache__/configuration_phi3.cpython-310.pyc ADDED
Binary file (8.67 kB). View file