WinstonHu committed on
Commit 759b14f · verified · 1 Parent(s): 65660dc

Delete code

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. code/xtuner/.DS_Store +0 -0
  2. code/xtuner/__init__.py +0 -25
  3. code/xtuner/__pycache__/__init__.cpython-311.pyc +0 -0
  4. code/xtuner/__pycache__/entry_point.cpython-311.pyc +0 -0
  5. code/xtuner/__pycache__/registry.cpython-311.pyc +0 -0
  6. code/xtuner/__pycache__/version.cpython-311.pyc +0 -0
  7. code/xtuner/_lite/.DS_Store +0 -0
  8. code/xtuner/_lite/__init__.py +0 -77
  9. code/xtuner/_lite/accelerate/__init__.py +0 -24
  10. code/xtuner/_lite/accelerate/lora.py +0 -5
  11. code/xtuner/_lite/accelerate/ops/__init__.py +0 -4
  12. code/xtuner/_lite/accelerate/ops/moe_permute.py +0 -200
  13. code/xtuner/_lite/accelerate/packed.py +0 -24
  14. code/xtuner/_lite/accelerate/utils.py +0 -62
  15. code/xtuner/_lite/algorithms/.DS_Store +0 -0
  16. code/xtuner/_lite/algorithms/__init__.py +0 -1
  17. code/xtuner/_lite/algorithms/ppo/__init__.py +0 -32
  18. code/xtuner/_lite/algorithms/ppo/dataset.py +0 -153
  19. code/xtuner/_lite/algorithms/ppo/loss.py +0 -119
  20. code/xtuner/_lite/algorithms/ppo/model.py +0 -49
  21. code/xtuner/_lite/algorithms/sft/__init__.py +0 -4
  22. code/xtuner/_lite/algorithms/sft/dataset.py +0 -109
  23. code/xtuner/_lite/chat/.DS_Store +0 -0
  24. code/xtuner/_lite/chat/__init__.py +0 -5
  25. code/xtuner/_lite/chat/backends/__init__.py +0 -1
  26. code/xtuner/_lite/chat/messages/__init__.py +0 -5
  27. code/xtuner/_lite/chat/messages/base.py +0 -32
  28. code/xtuner/_lite/chat/messages/chat.py +0 -202
  29. code/xtuner/_lite/chat/templates/__init__.py +0 -30
  30. code/xtuner/_lite/chat/templates/chat.py +0 -59
  31. code/xtuner/_lite/chat/templates/hybrid.py +0 -206
  32. code/xtuner/_lite/datasets/__init__.py +0 -14
  33. code/xtuner/_lite/datasets/json.py +0 -177
  34. code/xtuner/_lite/datasets/jsonl.py +0 -220
  35. code/xtuner/_lite/datasets/pack.py +0 -257
  36. code/xtuner/_lite/datasets/streaming.py +0 -28
  37. code/xtuner/_lite/datasets/utils/__init__.py +0 -12
  38. code/xtuner/_lite/datasets/utils/convert.py +0 -195
  39. code/xtuner/_lite/datasets/utils/load.py +0 -286
  40. code/xtuner/_lite/datasets/utils/utils.py +0 -66
  41. code/xtuner/_lite/device.py +0 -42
  42. code/xtuner/_lite/modelings/.DS_Store +0 -0
  43. code/xtuner/_lite/modelings/__init__.py +0 -17
  44. code/xtuner/_lite/modelings/internlm2/__init__.py +0 -2
  45. code/xtuner/_lite/modelings/internlm2/configuration_internlm2.py +0 -175
  46. code/xtuner/_lite/modelings/internlm2/modeling_internlm2.py +0 -1899
  47. code/xtuner/_lite/modelings/internlm3/__init__.py +0 -3
  48. code/xtuner/_lite/modelings/internlm3/configuration_internlm3.py +0 -197
  49. code/xtuner/_lite/modelings/internlm3/modeling_internlm3.py +0 -825
  50. code/xtuner/_lite/modelings/internlm3/tokenization_internlm3.py +0 -295
code/xtuner/.DS_Store DELETED
Binary file (12.3 kB)
 
code/xtuner/__init__.py DELETED
@@ -1,25 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import os
-
- from mmengine.utils import digit_version
-
- from .entry_point import cli
- from .version import __version__, version_info
-
- HF_CEPH_HUB = os.getenv('HF_CEPH_HUB', '')
- HF_USE_CEPH = os.getenv('HF_USE_CEPH', 0) or HF_CEPH_HUB != ''
- DS_CEPH_DIR = os.getenv('DS_CEPH_DIR', None)
- if HF_USE_CEPH:
-     from .utils.fileio import (patch_hf_auto_from_pretrained,
-                                patch_hf_save_pretrained)
-     patch_hf_auto_from_pretrained(HF_CEPH_HUB)
-     patch_hf_save_pretrained()
-
- if DS_CEPH_DIR:
-     from .utils.fileio import patch_deepspeed_engine
-     patch_deepspeed_engine()
-
- __all__ = [
-     '__version__', 'version_info', 'digit_version', 'cli', 'HF_USE_CEPH',
-     'DS_CEPH_DIR'
- ]
 
code/xtuner/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.08 kB)
 
code/xtuner/__pycache__/entry_point.cpython-311.pyc DELETED
Binary file (14.7 kB)
 
code/xtuner/__pycache__/registry.cpython-311.pyc DELETED
Binary file (408 Bytes)
 
code/xtuner/__pycache__/version.cpython-311.pyc DELETED
Binary file (1.39 kB)
 
code/xtuner/_lite/.DS_Store DELETED
Binary file (8.2 kB)
 
code/xtuner/_lite/__init__.py DELETED
@@ -1,77 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import os
- import subprocess
- import sys
-
- from loguru import logger
-
- from .device import get_device, get_torch_device_module
-
- _LOGGER = None
-
-
- def log_format(debug=False):
-     formatter = "[XTuner][{time:YYYY-MM-DD HH:mm:ss}][<level>{level}</level>]"
-
-     if debug:
-         formatter += "[<cyan>{name}</cyan>:"
-         formatter += "<cyan>{function}</cyan>:"
-         formatter += "<cyan>{line}</cyan>]"
-
-     formatter += " <level>{message}</level>"
-     return formatter
-
-
- def get_logger(level="INFO"):
-     global _LOGGER
-     if _LOGGER is None:
-         # Remove the original logger in Python to prevent duplicate printing.
-         logger.remove()
-         logger.add(sys.stderr, level=level, format=log_format(debug=level == "DEBUG"))
-         _LOGGER = logger
-     return _LOGGER
-
-
- def get_repo_git_info(repo_path):
-     original_directory = os.getcwd()
-     os.chdir(repo_path)
-
-     try:
-         branch = (
-             subprocess.check_output(
-                 ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.STDOUT
-             )
-             .strip()
-             .decode("utf-8")
-         )
-
-         commit_id = (
-             subprocess.check_output(
-                 ["git", "rev-parse", "HEAD"], stderr=subprocess.STDOUT
-             )
-             .strip()
-             .decode("utf-8")
-         )
-
-         remote_url = (
-             subprocess.check_output(
-                 ["git", "remote", "get-url", "origin"], stderr=subprocess.STDOUT
-             )
-             .strip()
-             .decode("utf-8")
-         )
-
-         return branch, commit_id, remote_url
-     except subprocess.CalledProcessError:
-         return None, None, None
-     finally:
-         os.chdir(original_directory)
-
-
- __all__ = [
-     "AutoConfig",
-     "AutoModelForCausalLM",
-     "AutoTokenizer",
-     "get_device",
-     "get_torch_device_module",
- ]
 
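For context, the get_logger above configured loguru once per process and returned the shared instance used throughout the removed code. A minimal usage sketch (the level and message text are illustrative, not from this commit):

    from xtuner._lite import get_logger

    logger = get_logger(level="DEBUG")  # first call installs the XTuner format once
    logger.info("dataset loaded")       # -> [XTuner][2025-01-01 00:00:00][INFO] dataset loaded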
code/xtuner/_lite/accelerate/__init__.py DELETED
@@ -1,24 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .lora import LORA_TARGET_MAP
- from .packed import pack_sequence, unpack_sequence
- from .utils import (
-     liger_kernel_is_available,
-     lmdeploy_is_available,
-     mlu_is_available,
-     npu_is_available,
-     profile_time_and_memory,
-     varlen_attn_is_available,
- )
-
- __all__ = [
-     "LORA_TARGET_MAP",
-     "pack_sequence",
-     "packed_sequence",
-     "unpack_sequence",
-     "liger_kernel_is_available",
-     "varlen_attn_is_available",
-     "lmdeploy_is_available",
-     "npu_is_available",
-     "mlu_is_available",
-     "profile_time_and_memory",
- ]
 
code/xtuner/_lite/accelerate/lora.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- LORA_TARGET_MAP = {
-     "InternLM2ForCausalLM": ["wqkv", "wo", "w1", "w2", "w3"],
-     "CLIPVisionModel": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
- }
 
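The map above keyed LoRA target modules by model class name. A hypothetical sketch of how it could feed a peft config (the peft wiring and hyperparameters are assumptions, not code from this commit):

    from peft import LoraConfig

    from xtuner._lite.accelerate import LORA_TARGET_MAP

    arch = "InternLM2ForCausalLM"  # hypothetical choice of architecture
    lora_cfg = LoraConfig(r=16, lora_alpha=32, target_modules=LORA_TARGET_MAP[arch])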
code/xtuner/_lite/accelerate/ops/__init__.py DELETED
@@ -1,4 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .moe_permute import GROUPED_GEMM_INSTALLED, permute_func, unpermute_func
-
- __all__ = ["GROUPED_GEMM_INSTALLED", "permute_func", "unpermute_func"]
 
code/xtuner/_lite/accelerate/ops/moe_permute.py DELETED
@@ -1,200 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- """Modified from
- https://github.com/fanshiqing/grouped_gemm/blob/v1.1.4/grouped_gemm/ops.py
- Support torch compile."""
- from typing import Optional, Tuple
-
- import torch
- from torch import Tensor
-
- GROUPED_GEMM_INSTALLED = False
-
- try:
-     from grouped_gemm import backend
-
-     GROUPED_GEMM_INSTALLED = True
- except ImportError:
-     # install grouped gemm https://github.com/fanshiqing/grouped_gemm/tree/v1.1.4?tab=readme-ov-file#pip-install
-     grouped_gmm = None
-
-
- @torch.library.custom_op("moe::permute", mutates_args=())
- def permute(input_act: Tensor, indices: Tensor, num_topK: int) -> Tuple[Tensor, Tensor]:
-     input_max_expanded_token_num = input_act.size(0) * num_topK
-     workspace_fw = []
-     permuted_act, row_id_map, _ = backend.permute(
-         input_act, indices, 0, workspace_fw, input_max_expanded_token_num
-     )
-     return permuted_act, row_id_map
-
-
- @permute.register_fake
- def permute_fake(
-     input_act: Tensor,
-     indices: Tensor,
-     num_topK: int,
- ):
-     permuted_act = input_act.new_empty(
-         (input_act.shape[0] * num_topK, *input_act.shape[1:])
-     )
-     row_id_map = indices.new_empty((indices.numel(),))
-     return permuted_act, row_id_map
-
-
- @torch.library.custom_op("moe::unpermute", mutates_args=())
- def unpermute(
-     input: Tensor, row_id_map: Tensor, prob: Tensor, max_tokens: int, num_topK: int
- ) -> Tensor:
-     if not input.is_contiguous():
-         input = input.contiguous()
-     return backend.unpermute(input, row_id_map, prob, max_tokens, num_topK)
-
-
- @unpermute.register_fake
- def unpermute_fake(
-     input: Tensor, row_id_map: Tensor, prob: Tensor, max_tokens: int, num_topK: int
- ) -> Tensor:
-     return input.new_empty((input.shape[0] // num_topK, *input.shape[1:]))
-
-
- @torch.library.custom_op("moe::unpermute_bwd", mutates_args=())
- def unpermute_bwd(
-     input_bwd: Tensor,
-     input_fwd: Tensor,
-     row_id_map: Tensor,
-     prob: Optional[Tensor],
- ) -> Tuple[Tensor, Tensor]:
-     if not input_bwd.is_contiguous():
-         input_bwd = input_bwd.contiguous()
-     topk = input_fwd.shape[0] // input_bwd.shape[0]
-     if prob is None:
-         prob = torch.ones(
-             [input_bwd.size(0), topk], dtype=torch.float32, device=input_bwd.device
-         )
-     return backend.unpermute_bwd(input_bwd, input_fwd, row_id_map, prob)
-
-
- @unpermute_bwd.register_fake
- def unpermute_bwd_fake(
-     input_bwd: Tensor,
-     input_fwd: Tensor,
-     row_id_map: Tensor,
-     prob: Optional[Tensor],
- ) -> Tuple[Tensor, Tensor]:
-     act_grad = torch.empty_like(input_fwd)
-     topk = input_fwd.shape[0] // input_bwd.shape[0]
-     prob_grad = torch.empty(
-         (input_bwd.size(0), topk), dtype=torch.float32, device=input_bwd.device
-     )
-     return act_grad, prob_grad
-
-
- if torch.__version__ >= "2.4.0":
-     _wrapped_permute = torch.ops.moe.permute
-     _wrapped_unpermute = torch.ops.moe.unpermute
-     _wrapped_unpermute_bwd = torch.ops.moe.unpermute_bwd
- else:
-     _wrapped_permute = permute
-     _wrapped_unpermute = unpermute
-     _wrapped_unpermute_bwd = unpermute_bwd
-
-
- class PermuteMoE_topK(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         input_act: Tensor,
-         indices: Tensor,
-     ):
-         if not input_act.numel():
-             return input_act, None
-
-         if indices.dim() == 1:
-             indices = indices.view(-1, 1)
-         if not input_act.is_contiguous():
-             input_act = input_act.contiguous()
-         if not indices.is_contiguous():
-             indices = indices.contiguous()
-
-         num_topK = indices.size(1)
-
-         permuted_act, row_id_map = _wrapped_permute(
-             input_act,
-             indices,
-             num_topK,
-         )
-
-         ctx.row_id_map = row_id_map
-         ctx.num_tokens = indices.size(0)
-         ctx.num_topK = num_topK
-         return permuted_act, row_id_map
-
-     @staticmethod
-     def backward(ctx, permuted_act_grad, *args):
-         if not permuted_act_grad.numel():
-             return permuted_act_grad, None
-
-         permuted_act_grad = permuted_act_grad.contiguous()
-
-         row_id_map = ctx.row_id_map
-         num_tokens = ctx.num_tokens
-         num_topK = ctx.num_topK
-
-         unpermuted_act_grad = _wrapped_unpermute(
-             permuted_act_grad, row_id_map, torch.tensor([]), num_tokens, num_topK
-         )
-         return unpermuted_act_grad, None
-
-
- class UnpermuteMoE_topK(torch.autograd.Function):
-     @staticmethod
-     def forward(ctx, input_act: Tensor, row_id_map: Tensor, probs: Tensor = None):
-         if not input_act.numel():
-             ctx.probs = probs
-             return input_act
-
-         if not input_act.is_contiguous():
-             input_act = input_act.contiguous()
-         if not row_id_map.is_contiguous():
-             row_id_map = row_id_map.contiguous()
-         if probs is not None and not probs.is_contiguous():
-             probs = probs.contiguous()
-
-         num_tokens = probs.size(0) if probs is not None else input_act.size(0)
-         num_topK = probs.size(1) if probs is not None else 1
-
-         unpermuted_output = _wrapped_unpermute(
-             input_act,
-             row_id_map,
-             probs if probs is not None else torch.tensor([]),
-             num_tokens,
-             num_topK,
-         )
-
-         ctx.save_for_backward(input_act, row_id_map, probs)
-         return unpermuted_output
-
-     @staticmethod
-     def backward(ctx, unpermuted_act_grad):
-         if not unpermuted_act_grad.numel():
-             return unpermuted_act_grad, None, ctx.probs
-
-         input_act, row_id_map, probs = ctx.saved_tensors
-
-         act_grad = None
-         if ctx.needs_input_grad[0]:
-             act_grad, prob_grad = _wrapped_unpermute_bwd(
-                 unpermuted_act_grad, input_act, row_id_map, probs
-             )
-
-         if not ctx.needs_input_grad[2]:
-             prob_grad = None
-         return act_grad, None, prob_grad
-
-
- def permute_func(input_act, indices):
-     return PermuteMoE_topK.apply(input_act, indices)
-
-
- def unpermute_func(input_act, row_id_map, probs=None):
-     return UnpermuteMoE_topK.apply(input_act, row_id_map, probs)
 
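permute_func/unpermute_func rearranged token activations into per-expert order and back, with gradients handled by the autograd.Function wrappers above. A hypothetical usage sketch, assuming grouped_gemm is installed and a CUDA device is available (all shapes are illustrative):

    import torch

    from xtuner._lite.accelerate.ops import permute_func, unpermute_func

    acts = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)             # 4 tokens, hidden 8
    topk_ids = torch.randint(0, 4, (4, 2), device="cuda", dtype=torch.int32)  # top-2 expert ids
    probs = torch.rand(4, 2, device="cuda")                                   # routing weights

    permuted, row_id_map = permute_func(acts, topk_ids)   # (8, 8): one row per (token, expert)
    merged = unpermute_func(permuted, row_id_map, probs)  # (4, 8): probability-weighted merge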
code/xtuner/_lite/accelerate/packed.py DELETED
@@ -1,24 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Union
-
- import torch
-
-
- def unpack_sequence(packed: torch.Tensor, num_tokens: Union[torch.Tensor, List], dim=1):
-     if isinstance(num_tokens, torch.Tensor):
-         num_tokens = num_tokens.tolist()
-     sequences = torch.split(packed, num_tokens, dim=dim)
-     return sequences
-
-
- def pack_sequence(sequences, dim=1):
-     num_tokens = torch.IntTensor([seq.size(dim) for seq in sequences])
-     packed = torch.cat(sequences, dim=dim)
-     return packed, num_tokens.to(packed.device)
-
-
- def packed_cumulative_length(num_tokens: torch.Tensor):
-     device = num_tokens.device
-     _zero_pad = torch.zeros(1, device=device)
-     _pad_length = torch.cat([_zero_pad, num_tokens]).int()
-     return torch.cumsum(_pad_length, 0).int()
 
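pack_sequence concatenated variable-length sequences along the token dimension and recorded their lengths, which unpack_sequence used to split them back; packed_cumulative_length produced cu_seqlens-style offsets for varlen attention. A minimal sketch with toy tensors:

    import torch

    from xtuner._lite.accelerate import pack_sequence, unpack_sequence

    a = torch.ones(1, 3, dtype=torch.long)      # sequence of 3 tokens
    b = torch.zeros(1, 5, dtype=torch.long)     # sequence of 5 tokens
    packed, num_tokens = pack_sequence([a, b])  # packed: (1, 8), num_tokens: tensor([3, 5])
    seqs = unpack_sequence(packed, num_tokens)  # shapes (1, 3) and (1, 5) again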
code/xtuner/_lite/accelerate/utils.py DELETED
@@ -1,62 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import time
- from contextlib import contextmanager
-
- from transformers.utils.import_utils import is_flash_attn_2_available
-
- from xtuner._lite import get_device, get_logger, get_torch_device_module
-
- logger = get_logger()
-
-
- def npu_is_available():
-     return get_device() == "npu"
-
-
- def mlu_is_available():
-     return get_device() == "mlu"
-
-
- def varlen_attn_is_available():
-     return is_flash_attn_2_available() or npu_is_available()
-
-
- def lmdeploy_is_available():
-     available = False
-     try:
-         import lmdeploy  # noqa: F401
-
-         available = True
-     except ImportError:
-         available = False
-
-     return available
-
-
- def liger_kernel_is_available():
-     available = False
-     try:
-         import liger_kernel  # noqa: F401
-
-         available = True
-     except ImportError:
-         available = False
-
-     return available
-
-
- @contextmanager
- def profile_time_and_memory(desc):
-     torch_device = get_torch_device_module()
-     start_t = time.time()
-     torch_device.reset_peak_memory_stats()
-
-     yield
-
-     max_memory = torch_device.max_memory_allocated()
-     cost_time = time.time() - start_t
-
-     logger.success(
-         f"{desc} Elapsed time {cost_time:.2f} seconds, "
-         f"peak gpu memory {max_memory/1024**3:.1f}G"
-     )
 
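profile_time_and_memory was a context manager that reset the device's peak-memory counter on entry and logged elapsed time plus peak memory on exit. A usage sketch (the wrapped workload is a hypothetical stand-in):

    from xtuner._lite.accelerate import profile_time_and_memory

    with profile_time_and_memory("[Build Model]"):
        model = build_model()  # hypothetical helper; any device workload fits here
    # -> "[Build Model] Elapsed time 12.34 seconds, peak gpu memory 7.9G"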
code/xtuner/_lite/algorithms/.DS_Store DELETED
Binary file (6.15 kB)
 
code/xtuner/_lite/algorithms/__init__.py DELETED
@@ -1 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
 
code/xtuner/_lite/algorithms/ppo/__init__.py DELETED
@@ -1,32 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .dataset import (
-     InferDataset,
-     PPOTokenizeFunction,
-     RewardBuffer,
-     RewardBufferCollator,
- )
- from .loss import (
-     CriticLoss,
-     PPOPolicyLoss,
-     compute_advantages_and_returns,
-     compute_kl_rewards,
-     gather_logprobs,
- )
- from .model import build_actor_model, build_reward_model
-
- __all__ = [
-     "InferDataset",
-     "RewardBuffer",
-     "RewardBufferCollator",
-     "PPOCollator",
-     "PPODataset",
-     "PPOTokenizeFunction",
-     "CriticLoss",
-     "PPOPolicyLoss",
-     "compute_advantages_and_returns",
-     "compute_kl_rewards",
-     "compute_rewards",
-     "gather_logprobs",
-     "build_actor_model",
-     "build_reward_model",
- ]
 
code/xtuner/_lite/algorithms/ppo/dataset.py DELETED
@@ -1,153 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import json
-
- import numpy as np
- import torch
- from torch import nn
-
- from xtuner._lite.chat.messages.chat import ChatMsg
- from xtuner._lite.datasets import OPENAI_CONVERT_MAP
-
- from ..sft import SftCollator, SftTokenizeFunction
-
-
- class InferDataset(torch.utils.data.Dataset):
-     def __init__(self, prompts, responses):
-         super().__init__()
-
-         assert len(prompts) == len(responses)
-         self.prompts = prompts
-         self.responses = responses
-         self.policies = None
-
-     def __len__(self):
-         return len(self.prompts)
-
-     def __getitem__(self, item):
-         prompt = self.prompts[item]
-         response = self.responses[item]
-         num_prefill_tokens = len(prompt)
-
-         input_ids = prompt + response
-         labels = [-100] * (num_prefill_tokens - 1) + response + [-100]
-
-         return {"input_ids": input_ids, "labels": labels, "num_tokens": len(input_ids)}
-
-
- FASTER = False
-
-
- class RewardBuffer(torch.utils.data.Dataset):
-     def __init__(self, clip_min=-5, clip_max=5, normalize=True, faster=False):
-         super().__init__()
-
-         self.clip_min = clip_min
-         self.clip_max = clip_max
-
-         self.normalize = normalize
-
-         if self.normalize:
-             self.bn = nn.BatchNorm1d(1, momentum=None, affine=False)
-         else:
-             self.bn = None
-
-         self._num_action_tokens = 0
-         self._num_total_tokens = 0
-         self._trajectories = []
-
-         self._current_mean = 0
-
-     @property
-     def running_mean(self):
-         return self.bn.running_mean.item()
-
-     @property
-     def current_mean(self):
-         return self._current_mean
-
-     @property
-     def num_action_tokens(self):
-         return self._num_action_tokens.item()
-
-     @property
-     def num_total_tokens(self):
-         return self._num_total_tokens
-
-     def update(self, trajectories):
-         rewards = [data["reward"] for data in trajectories]
-
-         for i in range(len(trajectories)):
-             trajectories[i]["ori_reward"] = trajectories[i]["reward"]
-
-         rewards = torch.tensor(rewards)
-
-         self._current_mean = rewards.mean().item()
-
-         rewards = rewards.clip(self.clip_min, self.clip_max)
-
-         if self.normalize:
-             self.bn.train()
-             _ = self.bn(rewards.unsqueeze(-1))
-             self.bn.eval()
-             rewards = self.bn(rewards.unsqueeze(-1))
-
-         for i in range(len(trajectories)):
-             trajectories[i]["reward"] = rewards[i].item()
-
-         num_total_tokens = 0
-         num_action_tokens = 0
-         for data in trajectories:
-             labels = np.array(data["labels"])
-             num_total_tokens += labels.size
-             num_action_tokens += (labels >= 0).sum()
-
-         self._num_action_tokens = num_action_tokens
-         self._num_total_tokens = num_total_tokens
-
-         self._trajectories = trajectories
-
-     def dump_jsonl(self, path, tokenizer, debug=False):
-         with open(path, "w", encoding="utf8") as f:
-             for data in self._trajectories:
-                 json_line = {
-                     "num_tokens": data["num_tokens"],
-                     "reward": data["ori_reward"],
-                     "sequence": tokenizer.decode(data["input_ids"]),
-                 }
-
-                 if debug:
-                     json_line["input_ids"] = data["input_ids"]
-                     json_line["labels"] = data["labels"]
-
-                 json_str = json.dumps(json_line, ensure_ascii=False)
-                 f.write(json_str + "\n")
-
-     def __len__(self):
-         return len(self._trajectories)
-
-     def __getitem__(self, item):
-         return self._trajectories[item]
-
-
- class PPOTokenizeFunction(SftTokenizeFunction):
-     def __init__(self, tokenizer, chat_template, raw_format="openai", sys_prompt=None):
-         super().__init__(tokenizer, chat_template, raw_format)
-         self.sys_prompt = sys_prompt
-
-     def __call__(self, item):
-         formatter = OPENAI_CONVERT_MAP[self.raw_format]
-         msg = formatter(item)
-         if self.sys_prompt is not None:
-             sys_msg = ChatMsg(role="system", content=self.sys_prompt)
-             msg.messages = [sys_msg] + msg.messages
-         tokenized = msg.tokenize(self.tokenizer, self.chat_template)
-
-         return tokenized
-
-
- class RewardBufferCollator(SftCollator):
-     def __call__(self, instances):
-         data = super().__call__(instances)
-         data["rewards"] = [item["reward"] for item in instances]
-
-         return data
 
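RewardBuffer stored one rollout's trajectories, clipped and (optionally) batch-normalized their rewards, and kept each raw value under "ori_reward". A toy sketch with fabricated numbers (two trajectories are used because BatchNorm1d cannot train on a single sample):

    from xtuner._lite.algorithms.ppo import RewardBuffer

    trajs = [
        {"reward": 1.7, "input_ids": [1, 2, 3], "labels": [-100, 2, 3], "num_tokens": 3},
        {"reward": -0.4, "input_ids": [1, 5, 6], "labels": [-100, 5, 6], "num_tokens": 3},
    ]
    buffer = RewardBuffer(clip_min=-5, clip_max=5, normalize=True)
    buffer.update(trajs)             # clips and normalizes "reward" in place
    print(buffer.current_mean)       # raw mean before clipping: 0.65
    print(buffer.num_action_tokens)  # count of labels >= 0 across the buffer: 4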
code/xtuner/_lite/algorithms/ppo/loss.py DELETED
@@ -1,119 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- from torch.nn import functional as F
-
- from xtuner._lite import get_logger
-
- logger = get_logger()
-
-
- def gather_logprobs(logits, labels):
-     log_probs = F.log_softmax(logits, dim=-1)
-     log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
-     return log_probs_labels.squeeze(-1)
-
-
- @torch.no_grad()
- def compute_kl_rewards(logprobs, ref_logprobs, reward_score, kl_coef=0.01):
-     assert logprobs.ndim == 1
-     last_mask = torch.zeros_like(logprobs, dtype=torch.int)
-     last_mask[-1] = 1
-
-     kl = ref_logprobs - logprobs
-     kl_reward = kl_coef * kl * (1 - last_mask)
-
-     last_reward = reward_score * last_mask
-
-     rewards = kl_reward + last_reward
-
-     return rewards
-
-
- @torch.no_grad()
- def compute_advantages_and_returns(values, rewards, gamma=1.0, gae_lambda=0.99):
-     # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134  # noqa: E501
-     """Function that computes advantages and returns from rewards and values.
-     Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
-     Note that rewards may include a KL divergence loss term.
-
-     Advantages looks like this:
-     Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
-            - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
-
-     Returns looks like this:
-     Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
-            + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
-     """
-     lastgaelam = 0
-     advantages_reversed = []
-
-     assert values.numel() == rewards.numel(), f"{values.numel()}, {rewards.numel()}"
-     length = rewards.numel()
-
-     for t in reversed(range(0, length)):
-         nextvalues = values[t + 1] if t < length - 1 else 0.0
-         # Since old_rewards and old_values are masked with action_mask,
-         # i.e. they have 0's at pad tokens,
-         # delta will be 0 if current t is at a pad token,
-         # so will lastgaelam
-         delta = rewards[t] + gamma * nextvalues - values[t]
-         lastgaelam = delta + gamma * gae_lambda * lastgaelam
-         advantages_reversed.append(lastgaelam)
-
-     advantages = torch.stack(advantages_reversed[::-1], dim=0)
-     returns = advantages + values
-     return advantages.detach(), returns
-
-
- class CriticLoss(torch.nn.Module):
-     """Loss function for critic model."""
-
-     def __init__(self, cliprange_value: float = 0.5, loss_type: str = "per_seq"):
-         super().__init__()
-         self.cliprange_value = cliprange_value
-         self.loss_type = loss_type
-
-         assert self.loss_type in ["per_token", "per_seq"]
-
-     def critic_loss_fn(self, values, old_values, returns, loss_factor=None):
-         values_clipped = old_values + (values - old_values).clamp(
-             -self.cliprange_value, self.cliprange_value
-         )
-         vf_loss1 = (values_clipped - returns) ** 2
-         vf_loss2 = (values - returns) ** 2
-         if self.loss_type == "per_seq":
-             vf_loss = torch.max(vf_loss1, vf_loss2).mean(-1)
-         elif self.loss_type == "per_token":
-             assert loss_factor is not None
-             vf_loss = torch.sum(torch.max(vf_loss1, vf_loss2) * loss_factor)
-         return 0.5 * vf_loss
-
-     def forward(self, values: torch.Tensor, old_values, returns, loss_factor=None):
-         loss = self.critic_loss_fn(
-             values=values,
-             old_values=old_values,
-             returns=returns,
-             loss_factor=loss_factor,
-         )
-         return loss
-
-
- class PPOPolicyLoss(torch.nn.Module):
-     """Loss function for policy model."""
-
-     def __init__(self, cliprange: float = 0.2, loss_type: str = "per_seq"):
-         super().__init__()
-         self.cliprange = cliprange
-         self.loss_type = loss_type
-         assert self.loss_type in ["per_token", "per_seq"]
-
-     def forward(self, logprobs, old_logprobs, advantages, loss_factor=None):
-         ratio = (logprobs - old_logprobs).exp()
-         pg_loss1 = -ratio * advantages
-         pg_loss2 = -ratio.clamp(1 - self.cliprange, 1 + self.cliprange) * advantages
-         if self.loss_type == "per_seq":
-             pg_loss = torch.max(pg_loss1, pg_loss2).mean(dim=-1)
-         elif self.loss_type == "per_token":
-             assert loss_factor is not None
-             pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2)) * loss_factor
-         return pg_loss
 
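The pieces above compose in the usual PPO order: gather_logprobs on policy logits, compute_kl_rewards against a reference model, GAE via compute_advantages_and_returns, then the clipped losses. A toy sketch on a single 4-token response (all numbers fabricated):

    import torch

    from xtuner._lite.algorithms.ppo import PPOPolicyLoss, compute_advantages_and_returns

    values = torch.tensor([0.10, 0.20, 0.30, 0.40])      # critic values per action token
    rewards = torch.tensor([-0.01, -0.02, -0.01, 1.00])  # KL shaping plus final reward score
    advantages, returns = compute_advantages_and_returns(values, rewards)

    loss_fn = PPOPolicyLoss(cliprange=0.2, loss_type="per_seq")
    logprobs = torch.randn(4, requires_grad=True)
    loss = loss_fn(logprobs, logprobs.detach(), advantages)  # ratio == 1 on the first update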
code/xtuner/_lite/algorithms/ppo/model.py DELETED
@@ -1,49 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
- from transformers.utils.import_utils import (
-     is_flash_attn_2_available,
-     is_torch_sdpa_available,
- )
-
- from xtuner._lite.accelerate import LoadWoInit
-
-
- def build_actor_model(model_path, dtype=torch.float32, trust_remote_code=True):
-     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-     if is_flash_attn_2_available():
-         config.attn_implementation = "flash_attention_2"
-     elif is_torch_sdpa_available():
-         config.attn_implementation = "sdpa"
-
-     with LoadWoInit():
-         policy = AutoModelForCausalLM.from_pretrained(
-             model_path,
-             attn_implementation="flash_attention_2",
-             torch_dtype=dtype,
-             trust_remote_code=trust_remote_code,
-         )
-
-     return policy
-
-
- def build_reward_model(model_path, dtype=torch.float32, trust_remote_code=True):
-     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-     if is_flash_attn_2_available():
-         config.attn_implementation = "flash_attention_2"
-     elif is_torch_sdpa_available():
-         config.attn_implementation = "sdpa"
-
-     config.use_cache = False
-     config.torch_dtype = dtype
-     with LoadWoInit():
-         reward = AutoModel.from_pretrained(
-             model_path,
-             attn_implementation="flash_attention_2",
-             torch_dtype=dtype,
-             trust_remote_code=trust_remote_code,
-         )
-
-     reward.model.use_cache = False
-
-     return reward
 
code/xtuner/_lite/algorithms/sft/__init__.py DELETED
@@ -1,4 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .dataset import SftCollator, SftTokenizeFunction
-
- __all__ = ["SftCollator", "SftTokenizeFunction"]
 
code/xtuner/_lite/algorithms/sft/dataset.py DELETED
@@ -1,109 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- from torch.nn.utils.rnn import pad_sequence
-
- from xtuner._lite import get_logger
- from xtuner._lite.datasets import OPENAI_CONVERT_MAP
-
- logger = get_logger()
-
-
- class SftTokenizeFunction:
-     def __init__(self, tokenizer, chat_template, raw_format="openai"):
-         self.tokenizer = tokenizer
-         self.chat_template = chat_template
-         self.raw_format = raw_format
-
-     def __call__(self, item):
-         formatter = OPENAI_CONVERT_MAP[self.raw_format]
-         msg = formatter(item)
-         tokenized = msg.tokenize(self.tokenizer, self.chat_template)
-         return tokenized
-
-
- class SftCollator:
-     def __init__(
-         self, pad_token_id=0, ignore_id=-100, pack_batch=False, max_length=None
-     ):
-         self.pack_batch = pack_batch
-         self.pad_token_id = pad_token_id
-         self.ignore_id = ignore_id
-         self.max_length = max_length
-
-     def __call__(self, instances):
-         _instances = []
-         for ins in instances:
-             if isinstance(ins, list):
-                 _instances.extend(ins)
-             else:
-                 _instances.append(ins)
-
-         instances = _instances
-
-         input_ids = []
-         labels = []
-         num_tokens = []
-
-         for data in instances:
-             _input_ids = data["input_ids"]
-             _labels = data["labels"]
-             _num_tokens = data["num_tokens"]
-
-             # TODO remove list
-             if isinstance(_num_tokens, list):
-                 assert len(_num_tokens) == 1
-                 _num_tokens = _num_tokens[0]
-
-             assert isinstance(_num_tokens, int)
-
-             if self.max_length:
-                 _input_ids = _input_ids[: self.max_length]
-                 _labels = _labels[: self.max_length]
-                 _num_tokens = min(_num_tokens, self.max_length)
-
-             input_ids.append(torch.LongTensor(_input_ids))
-             labels.append(torch.LongTensor(_labels))
-             num_tokens.append(_num_tokens)
-
-         attention_mask = [torch.ones_like(ids) for ids in input_ids]
-         num_tokens = torch.IntTensor(num_tokens)
-
-         if len(instances) > 1 and self.pack_batch:
-             input_ids = torch.cat(input_ids, dim=0).unsqueeze(0)
-             labels = torch.cat(labels, dim=0).unsqueeze(0)
-             attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0)
-
-         elif len(instances) > 1 and not self.pack_batch:
-             input_ids = pad_sequence(
-                 input_ids, batch_first=True, padding_value=self.pad_token_id
-             )
-             labels = pad_sequence(
-                 labels, batch_first=True, padding_value=self.ignore_id
-             )
-             attention_mask = pad_sequence(
-                 attention_mask, batch_first=True, padding_value=0
-             )
-         else:
-             input_ids = torch.stack(input_ids)
-             labels = torch.stack(labels)
-             attention_mask = torch.stack(attention_mask)
-
-         if input_ids.shape != labels.shape:
-             logger.error(f"[instances] {instances}")
-             logger.error(f"[num_tokens] {num_tokens}")
-             logger.error(f"[input_ids] {input_ids}")
-             logger.error(f"[labels] {labels}")
-             raise RuntimeError(
-                 "The shape of input_ids and labels must be "
-                 f"equal, but found {input_ids.shape} and "
-                 f"{labels.shape}."
-             )
-
-         data_dict = {
-             "input_ids": input_ids,
-             "labels": labels,
-             "num_tokens": num_tokens,
-             "attention_mask": attention_mask.bool(),
-         }
-
-         return data_dict
 
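SftCollator flattened nested instances, truncated to max_length, and either packed everything into one row (pack_batch=True) or right-padded per sample. A minimal sketch with two toy samples in the format SftTokenizeFunction produces:

    from torch.utils.data import DataLoader

    from xtuner._lite.algorithms.sft import SftCollator

    samples = [
        {"input_ids": [1, 2, 3], "labels": [-100, 2, 3], "num_tokens": 3},
        {"input_ids": [1, 5, 6, 7], "labels": [-100, 5, 6, 7], "num_tokens": 4},
    ]
    loader = DataLoader(samples, batch_size=2, collate_fn=SftCollator(pad_token_id=0))
    batch = next(iter(loader))  # input_ids/labels: (2, 4) padded; attention_mask: bool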
code/xtuner/_lite/chat/.DS_Store DELETED
Binary file (6.15 kB)
 
code/xtuner/_lite/chat/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .messages import ChatMessages
- from .templates import CHAT_TEMPLATE_MAP, ChatTemplate, HybridChatTemplate
-
- __all__ = ["ChatMessages", "CHAT_TEMPLATE_MAP", "ChatTemplate", "HybridChatTemplate"]
 
code/xtuner/_lite/chat/backends/__init__.py DELETED
@@ -1 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
 
code/xtuner/_lite/chat/messages/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .base import BaseMessages
- from .chat import ChatMessages
-
- __all__ = ["BaseMessages", "ChatMessages"]
 
code/xtuner/_lite/chat/messages/base.py DELETED
@@ -1,32 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from abc import abstractclassmethod, abstractmethod
- from typing import Dict
-
- from pydantic import BaseModel
- from transformers import PreTrainedTokenizer
-
- from ..templates import ChatTemplate
-
-
- class BaseMessages(BaseModel):
-     @abstractmethod
-     def add(self, role: str, content):
-         pass
-
-     @abstractmethod
-     def pop(self):
-         pass
-
-     @abstractmethod
-     def get_prompt(self, chat_template: ChatTemplate) -> str:
-         pass
-
-     @abstractmethod
-     def tokenize(
-         self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate
-     ) -> Dict:
-         pass
-
-     @abstractclassmethod
-     def from_dict(cls, item: Dict) -> "BaseMessages":
-         pass
 
code/xtuner/_lite/chat/messages/chat.py DELETED
@@ -1,202 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- from typing import Dict, List, Literal, Optional, Union
-
- from pydantic import BaseModel
- from transformers import PreTrainedTokenizer
-
- from xtuner._lite import get_logger
- from xtuner.utils import IGNORE_INDEX
-
- from ..templates import ChatTemplate, HybridChatTemplate
- from .base import BaseMessages
-
- logger = get_logger()
-
-
- class TextContentItem(BaseModel):
-     type: Literal["text"] = "text"
-     text: str
-
-     def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
-         return self.text
-
-
- class ImageContentItem(BaseModel):
-     type: Literal["image_url"] = "image_url"
-     image_url: str
-
-     def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
-         return chat_template.image_token
-
-
- MultModalContentType = Union[TextContentItem, ImageContentItem]
- ContentType = Union[str, List[MultModalContentType]]
-
-
- class ChatMsg(BaseModel):
-     role: Literal["assistant", "user", "system"]
-     content: ContentType
-     loss: Optional[bool] = None
-
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         if self.loss is None:
-             if self.role == "system":
-                 self.loss = False
-             elif self.role == "user":
-                 self.loss = False
-             elif self.role == "assistant":
-                 self.loss = True
-             else:
-                 raise NotImplementedError
-
-     def collect_img_urls(self) -> List[str]:
-         img_urls = []
-         if isinstance(self.content, list):
-             for item in self.content:
-                 if isinstance(item, ImageContentItem):
-                     img_urls.append(item.image_url)
-         return img_urls
-
-     def get_prompt(self, chat_template: ChatTemplate) -> str:
-         if isinstance(self.content, str):
-             text = self.content
-         elif isinstance(self.content, list):
-             text = ""
-             for i, item in enumerate(self.content):
-                 if i == 0:
-                     text += item.apply_chat_template(chat_template)
-                 else:
-                     text += "\n" + item.apply_chat_template(chat_template)
-         else:
-             raise NotImplementedError
-
-         if self.role == "system":
-             prompt = chat_template.decorate_system(text)
-         elif self.role == "user":
-             prompt = chat_template.decorate_user(text)
-         elif self.role == "assistant":
-             prompt = chat_template.decorate_assistant(text)
-         else:
-             raise NotImplementedError
-
-         return prompt
-
-     def tokenize(
-         self,
-         tokenizer: PreTrainedTokenizer,
-         chat_template: ChatTemplate,
-     ):
-         decorated = self.get_prompt(chat_template)
-
-         token_ids = tokenizer.encode(decorated, add_special_tokens=False)
-
-         if self.loss:
-             label_ids = copy.deepcopy(token_ids)
-         else:
-             label_ids = [IGNORE_INDEX] * len(token_ids)
-
-         return {
-             "input_ids": token_ids,
-             "labels": label_ids,
-         }
-
-
- class ChatMessages(BaseMessages):
-     messages: List[ChatMsg]
-
-     def add(self, role, content, loss=False):
-         self.messages.append(ChatMsg(role=role, content=content, loss=loss))
-
-     def pop(self):
-         return self.messages.pop()
-
-     def get_prompt(self, chat_template: ChatTemplate) -> str:
-         prompt = ""
-
-         for msg in self.messages:
-             prompt += msg.get_prompt(chat_template)
-             if msg.role == "assistant":
-                 prompt += chat_template.sep
-         return prompt
-
-     def tokenize(
-         self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate
-     ) -> Dict:
-         input_ids = tokenizer.encode("", add_special_tokens=True)
-         labels = [IGNORE_INDEX for _ in input_ids]
-         image_urls = []
-
-         for msg in self.messages:
-             res = msg.tokenize(tokenizer, chat_template)
-             token_ids, label_ids = res["input_ids"], res["labels"]
-
-             input_ids.extend(token_ids)
-             labels.extend(label_ids)
-
-             image_urls.extend(msg.collect_img_urls())
-
-             if msg.role == "assistant":
-                 sep = chat_template.sep
-                 sep_tokens = tokenizer.encode(sep, add_special_tokens=False)
-                 input_ids.extend(sep_tokens)
-                 labels.extend([IGNORE_INDEX] * len(sep_tokens))
-
-         if len(input_ids) != len(labels):
-             logger.error(f"[messages] {self.messages}")
-             logger.error(f"[input_ids] {input_ids}")
-             logger.error(f"[labels] {labels}")
-             raise RuntimeError(
-                 "The lengths of input_ids and labels must be "
-                 f"equal, but found {len(input_ids)} and "
-                 f"{len(labels)}."
-             )
-
-         training_data = {
-             "input_ids": input_ids,
-             "labels": labels,
-             "num_tokens": len(input_ids),
-         }
-
-         if len(image_urls) > 0:
-             training_data["image_urls"] = image_urls
-
-         return training_data
-
-     @classmethod
-     def from_str(cls, prompt: str) -> "ChatMessages":
-         msg = ChatMsg(role="user", content=prompt)
-         return cls(messages=[msg])
-
-     @classmethod
-     def from_dict(cls, item: dict) -> "ChatMessages":
-         """
-         item
-         {
-             'messages':[
-                 {'role':'user', 'content':'hello'},
-                 {'role':'assistant', 'content':'hello!'},
-             ],
-         }
-         """
-         return cls(**item)
-
-
- if __name__ == "__main__":
-     data = {
-         "messages": [
-             {"role": "user", "content": "hello"},
-             {"role": "assistant", "content": "hello!"},
-         ]
-     }
-
-     messages = ChatMessages.from_dict(data)
-     chat_template = ChatTemplate(
-         system="<|im_start|>system\n{system}<|im_end|>\n",
-         user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
-         assistant="{assistant}<|im_end|>\n",
-         stop_words=["<|im_end|>"],
-     )
-
-     print(messages.get_prompt(chat_template))
 
code/xtuner/_lite/chat/templates/__init__.py DELETED
@@ -1,30 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .chat import ChatTemplate
- from .hybrid import HybridChatTemplate
-
- CHAT_TEMPLATE_MAP = {
-     "internlm2": HybridChatTemplate(
-         system="<|im_start|>system\n{system}<|im_end|>\n",
-         user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
-         assistant="{assistant}<|im_end|>",
-         stop_words=["<|im_end|>"],
-     ),
-     "qwen2": HybridChatTemplate(
-         system="<|im_start|>system\n{system}<|im_end|>\n",
-         user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
-         assistant="{assistant}<|im_end|>",
-         stop_words=["<|im_end|>", "<|endoftext|>"],
-     ),
-     "llama3": HybridChatTemplate(
-         system=("<|start_header_id|>system<|end_header_id|>\n\n{system}" "<|eot_id|>"),
-         user=(
-             "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
-             "<|start_header_id|>assistant<|end_header_id|>\n\n"
-         ),
-         assistant="{assistant}<|eot_id|>",
-         sep="",
-         stop_words=["<|eot_id|>"],
-     ),
- }
-
- __all__ = ["ChatTemplate", "HybridChatTemplate"]
 
code/xtuner/_lite/chat/templates/chat.py DELETED
@@ -1,59 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List
-
- from pydantic import BaseModel, field_validator
-
-
- class ChatTemplate(BaseModel):
-     """Define a Pydantic data model for a hybrid chat with attributes for
-     system, user and assistant chat as well as function and interpreter calls
-     and results."""
-
-     # Normal Chat
-     system: str  # System message format
-     user: str  # User message format
-     assistant: str  # Assistant message format
-     stop_words: List[str]  # List of stop words
-     sep: str = "\n"
-
-     def decorate_system(self, text: str) -> str:
-         """Decorate text with the `system` template."""
-         return self.system.format(system=text)
-
-     def decorate_assistant(self, text: str) -> str:
-         """Decorate text with the `assistant` template."""
-         return self.assistant.format(assistant=text)
-
-     def decorate_user(self, text: str) -> str:
-         """Decorate text with the `user` template."""
-         return self.user.format(user=text)
-
-     @field_validator("system")
-     def check_system(cls, v: str) -> str:
-         """Validate that `system` contains '{system}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{system}" not in v:
-             raise ValueError("system must contain the keyword '{system}'")
-         return v
-
-     @field_validator("user")
-     def check_user(cls, v: str) -> str:
-         """Validate that `user` contains '{user}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{user}" not in v:
-             raise ValueError("user must contain the keyword '{user}'")
-         return v
-
-     @field_validator("assistant")
-     def check_assistant(cls, v: str) -> str:
-         """Validate that `assistant` contains '{assistant}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{assistant}" not in v:
-             raise ValueError("assistant must contain the keyword '{assistant}'")
-         return v
 
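ChatTemplate's validators guaranteed that each format string kept its placeholder, so the decorate_* helpers could never silently drop the text. A minimal sketch using the internlm2-style template defined in CHAT_TEMPLATE_MAP above:

    from xtuner._lite.chat.templates import ChatTemplate

    tmpl = ChatTemplate(
        system="<|im_start|>system\n{system}<|im_end|>\n",
        user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
        assistant="{assistant}<|im_end|>",
        stop_words=["<|im_end|>"],
    )
    print(tmpl.decorate_user("hi"))  # <|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n
    # Constructing with user="no placeholder" would instead raise a ValueError.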
code/xtuner/_lite/chat/templates/hybrid.py DELETED
@@ -1,206 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, List, Optional
-
- from pydantic import BaseModel, field_validator
-
-
- class HybridChatTemplate(BaseModel):
-     """Define a Pydantic data model for a hybrid chat with attributes for
-     system, user and assistant chat as well as function and interpreter calls
-     and results."""
-
-     # Normal Chat
-     system: str  # System message format
-     user: str  # User message format
-     assistant: str  # Assistant message format
-     stop_words: List[str]  # List of stop words
-     sep: str = "\n"
-
-     # Multimodal Chat
-     # Predefined token and index for images
-     image_token: str = "<image>"
-     image_token_index: int = -100
-
-     # Agent Chat
-
-     # Interpreter and function related strings
-     files: Optional[str] = None
-
-     functions: Optional[str] = None  # Function description format
-     function_call: Optional[str] = None  # Function call format
-     function_result: Optional[str] = None  # Function result format
-
-     code_interpreter: Optional[str] = None
-     code_interpreter_call: Optional[str] = None  # Interpreter call format
-     code_interpreter_result: Optional[str] = None  # Interpreter result format
-
-     function_token: Optional[str] = None
-     code_interpreter_token: Optional[str] = None
-     action_start_token: Optional[str] = None
-     action_end_token: Optional[str] = None
-
-     @property
-     def mm_token_maps(self) -> Dict[str, int]:
-         """Return a dictionary that maps multimodal tokens to corresponding
-         token indexes."""
-         return {self.image_token: self.image_token_index}
-
-     def decorate_system(self, text: str) -> str:
-         """Decorate text with the `system` template."""
-         return self.system.format(system=text)
-
-     def decorate_assistant(self, text: str) -> str:
-         """Decorate text with the `assistant` template."""
-         return self.assistant.format(assistant=text)
-
-     def decorate_user(self, text: str) -> str:
-         """Decorate text with the `user` template."""
-         return self.user.format(user=text)
-
-     def decorate_files(self, text: str) -> str:
-         """Decorate text with the `files` template."""
-         return self.files.format(files=text)
-
-     def decorate_functions(self, text: str) -> str:
-         """Decorate text with the `functions` template."""
-         return self.functions.format(functions=text)
-
-     def decorate_function_call(self, text: str, func: str) -> str:
-         """Decorate text with the `function_call` template."""
-         return self.function_call.format(assistant=text, function_call=func)
-
-     def decorate_function_result(self, text: str) -> str:
-         """Decorate text with the `function_result` template."""
-         return self.function_result.format(function_result=text)
-
-     def decorate_code_interpreter(self, text: str) -> str:
-         """Decorate text with the `code_interpreter` template."""
-         return self.code_interpreter.format(code_interpreter=text)
-
-     def decorate_code_interpreter_call(self, text: str, func: str) -> str:
-         """Decorate text with the `code_interpreter_call` template."""
-         return self.code_interpreter_call.format(
-             assistant=text, code_interpreter_call=func
-         )
-
-     def decorate_code_interpreter_result(self, text: str) -> str:
-         """Decorate text with the `code_interpreter_result` template."""
-         return self.code_interpreter_result.format(code_interpreter_result=text)
-
-     @field_validator("system")
-     def check_system(cls, v: str) -> str:
-         """Validate that `system` contains '{system}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{system}" not in v:
-             raise ValueError("system must contain the keyword '{system}'")
-         return v
-
-     @field_validator("user")
-     def check_user(cls, v: str) -> str:
-         """Validate that `user` contains '{user}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{user}" not in v:
-             raise ValueError("user must contain the keyword '{user}'")
-         return v
-
-     @field_validator("assistant")
-     def check_assistant(cls, v: str) -> str:
-         """Validate that `assistant` contains '{assistant}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{assistant}" not in v:
-             raise ValueError("assistant must contain the keyword '{assistant}'")
-         return v
-
-     @field_validator("function_call")
-     def check_function_call(cls, v: str) -> str:
-         """Validate that `function_call` contains '{function_call}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{function_call}" not in v and "{assistant}" not in v:
-             raise ValueError(
-                 "function_call must contain the keywords '{function_call}'"
-             )
-         if v is not None and "{assistant}" not in v:
-             raise ValueError(
-                 "function_call must contain the keyword '{assistant}' and "
-                 "'{function_call}'"
-             )
-         return v
-
-     @field_validator("function_result")
-     def check_function_result(cls, v: str) -> str:
-         """Validate that `function_result` contains '{function_result}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{function_result}" not in v:
-             raise ValueError(
-                 "function_result must contain the keyword '{function_result}'"
-             )
-         return v
-
-     @field_validator("functions")
-     def check_functions(cls, v: str) -> str:
-         """Validate that `functions` contains '{functions}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{functions}" not in v:
-             raise ValueError("functions must contain the keyword '{functions}'")
-         return v
-
-     @field_validator("code_interpreter")
-     def check_code_interpreter(cls, v: str) -> str:
-         """Validate that `code_interpreter` contains '{code_interpreter}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{code_interpreter}" not in v:
-             raise ValueError(
-                 "code_interpreter must contain the keyword '{code_interpreter}'"
-             )
-         return v
-
-     @field_validator("code_interpreter_call")
-     def check_code_interpreter_call(cls, v: str) -> str:
-         """Validate that `code_interpreter_call` contains
-         '{code_interpreter_call}'.
-
-         If not, raises a ValueError.
-         """
-         if (
-             v is not None
-             and "{code_interpreter_call}" not in v
-             and "{assistant}" not in v
-         ):
-             raise ValueError(
-                 "code_interpreter_call must contain the keywords "
-                 "'{assistant}' and '{code_interpreter_call}'"
-             )
-         if v is not None and "{assistant}" not in v:
-             raise ValueError(
-                 "code_interpreter_call must contain the keywords "
-                 "'{assistant}' and '{code_interpreter_call}'"
-             )
-         return v
-
-     @field_validator("code_interpreter_result")
-     def check_code_interpreter_result(cls, v: str) -> str:
-         """Validate that `code_interpreter_result` contains
-         '{code_interpreter_result}'.
-
-         If not, raises a ValueError.
-         """
-         if v is not None and "{code_interpreter_result}" not in v:
-             raise ValueError(
-                 "code_interpreter_result must contain the keyword "
-                 "'{code_interpreter_result}'"
-             )
-         return v
 
code/xtuner/_lite/datasets/__init__.py DELETED
@@ -1,14 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .json import JsonDataset
- from .jsonl import JsonlDataset
- from .pack import SoftPackDataset
- from .utils import DATASET_CLS_MAP, OPENAI_CONVERT_MAP, load_datasets
-
- __all__ = [
-     "JsonDataset",
-     "JsonlDataset",
-     "SoftPackDataset",
-     "DATASET_CLS_MAP",
-     "OPENAI_CONVERT_MAP",
-     "load_datasets",
- ]
 
code/xtuner/_lite/datasets/json.py DELETED
@@ -1,177 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import hashlib
3
- import inspect
4
- import json
5
- import math
6
- import os
7
- import random
8
- from concurrent.futures import ProcessPoolExecutor
9
-
10
- import numpy as np
11
- import torch
12
- from mmengine import mkdir_or_exist
13
- from torch import distributed as dist
14
- from tqdm import tqdm
15
-
16
- from xtuner._lite import get_logger
17
-
18
- logger = get_logger()
19
-
20
-
21
- def calculate_json_sha256(file_path):
22
- with open(file_path, "rb") as f:
23
- data = f.read()
24
-
25
- hash_object = hashlib.sha256(data)
26
-     hash_hex = hash_object.hexdigest()
-     return hash_hex
- 
- 
- def calculate_tokenize_fn_sha256(tokenize_fn):
-     """Calculate SHA-256 hash for an instance method's source code."""
-     # Get the source code of the method
-     fn_source = inspect.getsource(tokenize_fn.__call__)
-     return hashlib.sha256(fn_source.encode("utf-8")).hexdigest()
- 
- 
- class JsonDataset(torch.utils.data.Dataset):
-     def __init__(
-         self, path, sample_ratio=1.0, tokenize_fn=None, cache_dir=None, max_length=None
-     ):
-         super().__init__()
- 
-         self.tokenize_fn = tokenize_fn
-         self.path = path
-         self.tokenizer_workers = int(os.environ.get("XTUNER_TOKENIZE_WORKERS", 8))
- 
-         if cache_dir:
-             if os.path.exists(cache_dir):
-                 assert os.path.isdir(cache_dir)
-             else:
-                 mkdir_or_exist(cache_dir)
- 
-             file_hash = calculate_json_sha256(path)
-             file_cache_dir = os.path.join(cache_dir, file_hash)
- 
-             if file_hash not in os.listdir(cache_dir):
-                 mkdir_or_exist(file_cache_dir)
- 
-             if self.tokenize_fn:
-                 tok_hash = calculate_tokenize_fn_sha256(tokenize_fn)
-                 tok_cache_dir = os.path.join(file_cache_dir, tok_hash)
-                 if tok_hash not in os.listdir(file_cache_dir):
-                     mkdir_or_exist(tok_cache_dir)
- 
-                 if "num_tokens.npy" in os.listdir(tok_cache_dir):
-                     _cached_file = os.path.join(tok_cache_dir, "num_tokens.npy")
-                     num_tokens = np.load(_cached_file)
-                 else:
-                     num_tokens = self.count_tokens(tok_cache_dir)
-             else:
-                 num_tokens = None
- 
-         else:
-             num_tokens = None
- 
-         with open(self.path) as f:
-             dataset = json.load(f)
- 
-         _sampled = [i for i in range(len(dataset))]
- 
-         if max_length is not None:
-             assert isinstance(max_length, int)
-             # Length filtering needs token counts; compute them if they were
-             # not loaded from the cache above.
-             if num_tokens is None:
-                 assert self.tokenize_fn
-                 num_tokens = self.count_tokens()
-             _filtered = [
-                 x for i, x in enumerate(_sampled) if num_tokens[i] < max_length
-             ]
- 
-             if len(_filtered) < len(_sampled):
-                 missed_num = len(_sampled) - len(_filtered)
-                 logger.warning(
-                     f"{path}: discarded {missed_num} samples longer than "
-                     f"{max_length} tokens."
-                 )
- 
-             _sampled = _filtered
- 
-         _target_num_samples = int(len(_sampled) * sample_ratio)
-         self.sampled = _sampled * int(sample_ratio)
-         self.sampled.extend(
-             random.sample(_sampled, _target_num_samples - len(self.sampled))
-         )
- 
-         if num_tokens is not None:
-             num_tokens = num_tokens[self.sampled]
- 
-         self.num_tokens = num_tokens
-         self.dataset = None
- 
-     def count_tokens(self, cache_dir=None):
-         with open(self.path) as f:
-             dataset = json.load(f)
- 
-         num_samples = len(dataset)
- 
-         # `dist.is_available()` alone is not enough: the default process
-         # group must also be initialized before rank queries are legal.
-         if dist.is_available() and dist.is_initialized():
-             world_size = dist.get_world_size()
-             rank = dist.get_rank()
-         else:
-             world_size = 1
-             rank = 0
- 
-         num_per_rank = math.ceil(num_samples / world_size)
- 
-         start = rank * num_per_rank
-         end = (rank + 1) * num_per_rank
-         dataset_shard = dataset[start:end]
- 
-         desc = f"[Rank {rank}] {self.path}"
-         chunk_size = min(1024, max(1, len(dataset_shard) // self.tokenizer_workers))
-         with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor:
-             tokenized = list(
-                 tqdm(
-                     executor.map(self.tokenize_fn, dataset_shard, chunksize=chunk_size),
-                     desc=desc,
-                     total=len(dataset_shard),
-                 )
-             )
- 
-         _num_tokens = [data["num_tokens"] for data in tokenized]
-         _num_tokens = np.array(_num_tokens)
- 
-         if dist.is_available() and dist.is_initialized():
-             num_tokens = [None] * world_size
-             dist.all_gather_object(num_tokens, _num_tokens)
-             num_tokens = np.concatenate(num_tokens, axis=0)
-         else:
-             num_tokens = _num_tokens
- 
-         if rank == 0 and cache_dir:
-             save_path = os.path.join(cache_dir, "num_tokens.npy")
-             np.save(save_path, num_tokens)
- 
-         return num_tokens
- 
-     def __len__(self):
-         return len(self.sampled)
- 
-     def __getitem__(self, item):
-         """Returns the sample at the given index.
- 
-         Args:
-             item: An index into the (possibly resampled) dataset.
- 
-         Returns:
-             The tokenized sample if `tokenize_fn` is set, otherwise the raw
-             JSON record.
-         """
-         if self.dataset is None:
-             with open(self.path) as f:
-                 self.dataset = json.load(f)
- 
-         raw_data = self.dataset[self.sampled[item]]
- 
-         if self.tokenize_fn:
-             tokenized_data = self.tokenize_fn(raw_data)
-             return tokenized_data
-         else:
-             return raw_data
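
For reference, a minimal usage sketch of the deleted `JsonDataset` (the toy tokenize function, temp path, and sample texts are illustrative assumptions, not taken from this repo):

    import json
    import os
    import tempfile

    # Toy whitespace "tokenizer"; real callers wrap a HF tokenizer. It must
    # return a dict that at least carries `num_tokens`.
    def toy_tokenize_fn(sample):
        input_ids = sample["text"].split()
        return {"input_ids": input_ids, "num_tokens": len(input_ids)}

    path = os.path.join(tempfile.mkdtemp(), "toy.json")
    with open(path, "w") as f:
        json.dump([{"text": "hello world"}, {"text": "a longer sample here"}], f)

    dataset = JsonDataset(path, sample_ratio=1.0, tokenize_fn=toy_tokenize_fn)
    print(len(dataset), dataset[0]["num_tokens"])  # 2 samples; the first has 2 tokens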
 
code/xtuner/_lite/datasets/jsonl.py DELETED
@@ -1,220 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import hashlib
- import json
- import math
- import multiprocessing
- import os
- import random
- from abc import ABC, abstractmethod
- from concurrent.futures import ProcessPoolExecutor
- from typing import Any, Callable, TypedDict
- 
- import numpy as np
- import torch
- from mmengine import mkdir_or_exist
- from torch import distributed as dist
- from tqdm import tqdm
- 
- from xtuner._lite import get_logger
- 
- logger = get_logger()
- 
- 
- def calculate_jsonl_sha256(path):
-     with open(path, "rb") as f:
-         file_hash = hashlib.sha256()
-         file_hash.update(f.read())
-     return file_hash.hexdigest()
- 
- 
- CacheObj = TypedDict("CacheObj", {"num_tokens": int}, total=False)
- 
- 
- class CachableTokenizeFunction(ABC):
-     @abstractmethod
-     def __call__(self, item: Any) -> CacheObj:
-         raise NotImplementedError
- 
-     @abstractmethod
-     def hash(self) -> str:
-         raise NotImplementedError
- 
- 
- class JsonlDataset(torch.utils.data.Dataset):
-     def __init__(
-         self,
-         path,
-         sample_ratio: float = 1.0,
-         tokenize_fn: Callable[[Any], CacheObj] | None = None,
-         cache_dir: str | None = None,
-         max_length: int | None = None,
-     ):
-         super().__init__()
- 
-         self.tokenize_fn = tokenize_fn
-         self.path = path
-         self.tokenizer_workers = int(os.environ.get("XTUNER_TOKENIZE_WORKERS", 8))
- 
-         if cache_dir and isinstance(tokenize_fn, CachableTokenizeFunction):
-             if os.path.exists(cache_dir):
-                 assert os.path.isdir(cache_dir)
-             else:
-                 mkdir_or_exist(cache_dir)
- 
-             file_hash = calculate_jsonl_sha256(path)
-             file_cache_dir = os.path.join(cache_dir, file_hash)
- 
-             if file_hash not in os.listdir(cache_dir):
-                 mkdir_or_exist(file_cache_dir)
- 
-             if "offsets.npy" in os.listdir(file_cache_dir):
-                 _cached_file = os.path.join(file_cache_dir, "offsets.npy")
-                 offsets = np.load(_cached_file)
-             else:
-                 offsets = self.count_offsets(file_cache_dir)
- 
-             if self.tokenize_fn:
-                 tok_hash = tokenize_fn.hash()
-                 tok_cache_dir = os.path.join(file_cache_dir, tok_hash)
-                 if tok_hash not in os.listdir(file_cache_dir):
-                     mkdir_or_exist(tok_cache_dir)
- 
-                 if "num_tokens.npy" in os.listdir(tok_cache_dir):
-                     _cached_file = os.path.join(tok_cache_dir, "num_tokens.npy")
-                     num_tokens = np.load(_cached_file)
-                 else:
-                     num_tokens = self.count_tokens(offsets, tok_cache_dir)
-             else:
-                 num_tokens = None
- 
-         else:
-             offsets = self.count_offsets()
-             num_tokens = None
-             if max_length is not None:
-                 assert self.tokenize_fn
-                 num_tokens = self.count_tokens(offsets)
- 
-         _sampled = [i for i in range(len(offsets))]
- 
-         if max_length is not None:
-             assert isinstance(max_length, int)
-             _filtered = [
-                 x for i, x in enumerate(_sampled) if num_tokens[i] < max_length
-             ]
- 
-             if len(_filtered) < len(_sampled):
-                 missed_num = len(_sampled) - len(_filtered)
-                 logger.warning(
-                     f"{path}: discarded {missed_num} samples longer than "
-                     f"{max_length} tokens."
-                 )
- 
-             _sampled = _filtered
- 
-         _target_num_samples = int(len(_sampled) * sample_ratio)
-         self.sampled = _sampled * int(sample_ratio)
-         self.sampled.extend(
-             random.sample(_sampled, _target_num_samples - len(self.sampled))
-         )
- 
-         if num_tokens is not None:
-             num_tokens = num_tokens[self.sampled]
- 
-         self.num_tokens = num_tokens
-         self.offsets = offsets[self.sampled]
- 
-     def count_offsets(self, cache_dir=None):
-         offsets = [0]
-         with open(self.path) as f:
-             lines = f.readlines()
-             for line in lines[:-1]:
-                 offsets.append(offsets[-1] + len(line.encode()))
- 
-         offsets = np.array(offsets)
- 
-         # Only rank 0 writes the cache; the rank query is guarded so this
-         # also works without an initialized process group.
-         rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
-         if rank == 0 and cache_dir:
-             save_path = os.path.join(cache_dir, "offsets.npy")
-             np.save(save_path, offsets)
- 
-         return offsets
- 
-     def _tokenize_by_offset(self, offset):
-         with open(self.path) as f:
-             f.seek(offset)
-             data = json.loads(f.readline())
-         return self.tokenize_fn(data)
- 
-     def count_tokens(self, offsets, cache_dir=None):
-         num_samples = len(offsets)
- 
-         if dist.is_available() and dist.is_initialized():
-             world_size = dist.get_world_size()
-             rank = dist.get_rank()
-         else:
-             world_size = 1
-             rank = 0
- 
-         num_per_rank = math.ceil(num_samples / world_size)
- 
-         start = rank * num_per_rank
-         end = (rank + 1) * num_per_rank
-         offsets_shard = offsets[start:end]
- 
-         desc = f"[Rank {rank}] {self.path}"
-         chunk_size = min(1024, max(1, len(offsets_shard) // self.tokenizer_workers))
- 
-         mp_context = multiprocessing.get_context("fork")
-         with ProcessPoolExecutor(
-             max_workers=self.tokenizer_workers, mp_context=mp_context
-         ) as executor:
-             tokenized = list(
-                 tqdm(
-                     executor.map(
-                         self._tokenize_by_offset, offsets_shard, chunksize=chunk_size
-                     ),
-                     desc=desc,
-                     total=len(offsets_shard),
-                 )
-             )
- 
-         _num_tokens = [data["num_tokens"] for data in tokenized]
-         _num_tokens = np.array(_num_tokens)
- 
-         if dist.is_available() and dist.is_initialized():
-             num_tokens = [None] * world_size
-             dist.all_gather_object(num_tokens, _num_tokens)
-             num_tokens = np.concatenate(num_tokens, axis=0)
-         else:
-             num_tokens = _num_tokens
- 
-         if rank == 0 and cache_dir:
-             save_path = os.path.join(cache_dir, "num_tokens.npy")
-             np.save(save_path, num_tokens)
- 
-         return num_tokens
- 
-     def __len__(self):
-         return len(self.offsets)
- 
-     def __getitem__(self, item):
-         """Returns the sample at the given index.
- 
-         Args:
-             item: An index into the (possibly resampled) dataset.
- 
-         Returns:
-             The tokenized sample if `tokenize_fn` is set, otherwise the raw
-             JSON record.
-         """
-         with open(self.path) as f:
-             f.seek(self.offsets[item])
-             line = f.readline()
- 
-         raw_data = json.loads(line)
- 
-         if self.tokenize_fn:
-             tokenized_data = self.tokenize_fn(raw_data)
-             return tokenized_data
-         else:
-             return raw_data
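
A sketch of pairing `JsonlDataset` with a `CachableTokenizeFunction`, whose `hash()` keys the on-disk `num_tokens.npy` cache (the class and temp file below are illustrative assumptions):

    import hashlib
    import inspect
    import json
    import os
    import tempfile

    class ToyTokenizeFunction(CachableTokenizeFunction):
        def __call__(self, item):
            ids = item["text"].split()
            return {"input_ids": ids, "num_tokens": len(ids)}

        def hash(self):
            # Hash our own source code so cached token counts are invalidated
            # whenever the tokenization logic changes.
            src = inspect.getsource(type(self))
            return hashlib.sha256(src.encode("utf-8")).hexdigest()

    path = os.path.join(tempfile.mkdtemp(), "toy.jsonl")
    with open(path, "w") as f:
        for text in ("hello world", "one more line"):
            f.write(json.dumps({"text": text}) + "\n")

    dataset = JsonlDataset(path, tokenize_fn=ToyTokenizeFunction(), cache_dir=None)
    print(len(dataset), dataset[0]["num_tokens"])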
 
code/xtuner/_lite/datasets/pack.py DELETED
@@ -1,257 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import bisect
- import itertools
- import random
- 
- import numpy as np
- import torch
- from datasets import Dataset, concatenate_datasets
- from torch.utils.data import ConcatDataset
- 
- 
- class SoftPackDataset(torch.utils.data.Dataset):
-     def __init__(self, datasets, target=2048, blend=False, sort=False):
-         if blend:
-             num_tokens = [np.concatenate([dset.num_tokens for dset in datasets])]
-             datasets = [ConcatDataset(datasets)]
-         else:
-             num_tokens = [dset.num_tokens for dset in datasets]
-         self.datasets = datasets
-         self.target = target
- 
-         pack_infos = []
-         for i, dataset in enumerate(self.datasets):
-             _infos = self.get_pack_infos(dataset, i, num_tokens[i])
-             pack_infos.append(_infos)
-         self.pack_infos = concatenate_datasets(pack_infos)
- 
-     @property
-     def longest(self):
-         return self.pack_infos["longest"]
- 
-     def get_pack_infos(self, dataset, dataset_id, num_tokens):
-         inds = [i for i in range(len(dataset))]
-         random.shuffle(inds)
- 
-         item_buffer = []
-         length_buffer = []
-         longest = 0
- 
-         pack_infos = []
-         for shfl_i in inds:
-             if num_tokens[shfl_i] + sum(length_buffer) <= self.target:
-                 item_buffer.append(shfl_i)
-                 length_buffer.append(num_tokens[shfl_i])
-                 longest = max(longest, num_tokens[shfl_i])
-             else:
-                 if len(item_buffer) > 0:
-                     info = {
-                         "dataset_id": dataset_id,
-                         "indices": item_buffer,
-                         "longest": int(longest),
-                     }
-                     pack_infos.append(info)
- 
-                 item_buffer = [shfl_i]
-                 length_buffer = [num_tokens[shfl_i]]
-                 longest = num_tokens[shfl_i]
- 
-         if len(item_buffer) > 0:
-             info = {
-                 "dataset_id": dataset_id,
-                 "indices": item_buffer,
-                 "longest": int(longest),
-             }
- 
-             pack_infos.append(info)
- 
-         pack_infos = Dataset.from_list(pack_infos)
- 
-         return pack_infos
- 
-     def __len__(self):
-         return len(self.pack_infos)
- 
-     def __getitem__(self, item):
-         indices = self.pack_infos[item]["indices"]
-         dataset_id = self.pack_infos[item]["dataset_id"]
-         return [self.datasets[dataset_id][i] for i in indices]
- 
- 
- class HardPackDataset(torch.utils.data.Dataset):
-     def __init__(self, datasets, target=2048, blend=True, sort=False):
-         if blend:
-             num_tokens = [np.concatenate([dset.num_tokens for dset in datasets])]
-             datasets = [ConcatDataset(datasets)]
-         else:
-             num_tokens = [dset.num_tokens for dset in datasets]
-         self.datasets = datasets
-         self.target = target
- 
-         pack_infos = []
-         for i, dataset in enumerate(self.datasets):
-             _info = self.get_pack_info(dataset, i, num_tokens[i])
-             pack_infos.append(_info)
- 
-         _ranges_left = []
-         _ranges_right = []
-         _num_packed_samples = []
-         _indices = []
-         _max_length_per_pack = []
-         _dataset_id = []
-         for info in pack_infos:
-             _ranges_left.extend(info["ranges_left"])
-             _ranges_right.extend(info["ranges_right"])
-             _num_packed_samples.append(info["num_packed_samples"])
-             _indices.extend(info["indices"])
-             _max_length_per_pack.extend(info["max_length_per_pack"])
-             _dataset_id.extend(info["dataset_id"])
- 
-         self.pack_infos = {
-             "ranges_left": _ranges_left,
-             "ranges_right": _ranges_right,
-             "num_packed_samples": _num_packed_samples,
-             "indices": _indices,
-             "max_length_per_pack": _max_length_per_pack,
-             "dataset_id": _dataset_id,
-         }
- 
-     @classmethod
-     def _cal_max_length(cls, begin, end, shfl_item_rngs_left, shfl_item_rngs_right):
-         left = bisect.bisect(shfl_item_rngs_right, begin)
-         right = bisect.bisect(shfl_item_rngs_left, end)
-         max_length = 0
-         for i in range(left, right):
-             item_begin = shfl_item_rngs_left[i]
-             item_end = shfl_item_rngs_right[i]
-             inner_l = max(begin, item_begin) - item_begin
-             inner_r = min(end, item_end) - item_begin
-             trunc_size = inner_r - inner_l
-             max_length = max(max_length, trunc_size)
-         return max_length
- 
-     def get_pack_info(self, dataset, dataset_id, num_tokens):
-         # The number of data items after packing
-         num_packed_samples = int(num_tokens.sum() / self.target)
- 
-         # Shuffle the order of the original dataset.
-         # The packing will proceed according to the order after shuffle.
-         # Assume the following conditions hold:
-         # (1) shfl_inds = [3, 1, 2, 0]
-         # (2) self._ori_lens[3] + self._ori_lens[1] = max_length
-         # (3) self._ori_lens[2] + self._ori_lens[0] = max_length
-         # Ultimately, dataset[3] and dataset[1] will be combined into a new
-         # data, and dataset[2] and dataset[0] will be combined into a new data.
-         inds = [i for i in range(len(dataset))]
-         random.shuffle(inds)
-         shfl_inds = inds
- 
-         # shuffled cumulative lengths
-         shfl_lens = [num_tokens[i] for i in shfl_inds]
-         shfl_acc_lens = list(itertools.accumulate(shfl_lens))
- 
-         shfl_item_rngs_left = [0] + shfl_acc_lens[:-1]
-         shfl_item_rngs_right = shfl_acc_lens
- 
-         max_length_per_pack = []
-         belong_dataset_ids = []
-         for i in range(num_packed_samples):
-             begin = i * self.target
-             end = (i + 1) * self.target
-             max_length_per_pack.append(
-                 self._cal_max_length(
-                     begin, end, shfl_item_rngs_left, shfl_item_rngs_right
-                 )
-             )
-             belong_dataset_ids.append(dataset_id)
- 
-         pack_infos = {
-             "ranges_left": shfl_item_rngs_left,
-             "ranges_right": shfl_item_rngs_right,
-             "num_packed_samples": num_packed_samples,
-             "indices": shfl_inds,
-             "dataset_id": belong_dataset_ids,
-             "max_length_per_pack": max_length_per_pack,
-         }
- 
-         return pack_infos
- 
-     def _pack_ids_and_labels_in_range(self, begin: int, end: int):
-         """Packs ids and labels in a given range using bisection.
- 
-         Args:
-             begin: Token offset marking the beginning of the range.
-             end: Token offset marking the end of the range.
- 
-         Returns:
-             A tuple containing packed ids, labels, and per-chunk lengths.
-         """
- 
-         # Use binary search to find dataset items that overlap the
-         # [begin, end) range. As in `_cal_max_length`, an item overlaps iff
-         # its right edge is past `begin` and its left edge is before `end`.
-         left = bisect.bisect(self.pack_infos["ranges_right"], begin)
-         right = bisect.bisect(self.pack_infos["ranges_left"], end)
- 
-         trunc_input_ids = []
-         trunc_labels = []
-         trunc_sizes = []
- 
-         for i in range(left, right):
-             # Determine the real range we will cut in current original item
-             item_begin = self.pack_infos["ranges_left"][i]
-             item_end = self.pack_infos["ranges_right"][i]
- 
-             # Calculate exact positions within current dataset item
-             inner_l = max(begin, item_begin) - item_begin
-             inner_r = min(end, item_end) - item_begin
- 
-             # Get original data and labels
-             ori_idx = self.pack_infos["indices"][i]
-             ori_dataset_id = self.pack_infos["dataset_id"][i]
-             ori_input_ids = self.datasets[ori_dataset_id][ori_idx]["input_ids"]
-             ori_labels = self.datasets[ori_dataset_id][ori_idx]["labels"]
- 
-             # Add original data and labels from calculated positions
-             # to trunc_ids and trunc_labels
-             trunc_input_ids.extend(ori_input_ids[inner_l:inner_r])
-             trunc_labels.extend(ori_labels[inner_l:inner_r])
-             trunc_sizes.append(inner_r - inner_l)
- 
-         # return populated lists of truncated ids, labels and their lengths
-         return trunc_input_ids, trunc_labels, trunc_sizes
- 
-     def __len__(self):
-         # One `max_length_per_pack` entry exists per packed sample.
-         return len(self.pack_infos["max_length_per_pack"])
- 
-     def __getitem__(self, item):
-         """Returns a dict containing packed data in the given item.
- 
-         Args:
-             item: An index to retrieve packed data.
- 
-         Returns:
-             A dict including packed input_ids, labels, and num_tokens.
-         """
-         # The cumulative length from the start position of this data
-         begin = item * self.target
-         # The cumulative length from the end position of this data
-         end = (item + 1) * self.target
- 
-         # Extract data within the range from the shuffled original dataset.
-         _res = self._pack_ids_and_labels_in_range(begin, end)
-         packed_input_ids, packed_labels, num_tokens = _res
-         assert self.target == len(packed_input_ids) == len(packed_labels)
- 
-         packed = {
-             "input_ids": packed_input_ids,
-             "labels": packed_labels,
-             "num_tokens": num_tokens,
-         }
- 
-         return packed
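
A sketch of how the deleted `SoftPackDataset` groups variable-length samples into packs of at most `target` tokens (the stand-in dataset below is an illustrative assumption):

    import numpy as np
    import torch

    # Stand-in for JsonlDataset/JsonDataset: exposes `.num_tokens` plus
    # per-item dicts carrying `input_ids` and `labels`.
    class ToyDataset(torch.utils.data.Dataset):
        def __init__(self, lengths):
            self.num_tokens = np.array(lengths)

        def __len__(self):
            return len(self.num_tokens)

        def __getitem__(self, i):
            n = int(self.num_tokens[i])
            return {"input_ids": list(range(n)), "labels": list(range(n))}

    packed = SoftPackDataset([ToyDataset([1000, 800, 600, 400])], target=2048)
    print(len(packed))                               # a handful of packs
    print([len(s["input_ids"]) for s in packed[0]])  # samples grouped into pack 0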
 
code/xtuner/_lite/datasets/streaming.py DELETED
@@ -1,28 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- 
- 
- class Streaming:
-     def __init__(self, file, max_epoch=1):
-         self.file = file
-         self.offset = 0
-         self.epoch = 1
-         self.max_epoch = max_epoch
- 
-     def __iter__(self):
-         return self
- 
-     def __next__(self):
-         with open(self.file) as f:
-             f.seek(self.offset)
-             line = f.readline()
- 
-             if not line and self.epoch < self.max_epoch:
-                 self.offset = 0
-                 self.epoch += 1
-                 return next(self)
- 
-             elif not line and self.epoch == self.max_epoch:
-                 raise StopIteration
- 
-             self.offset = f.tell()
-             return line
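
`Streaming` reads a file line by line without keeping it open between reads, optionally replaying it for several epochs. A quick sketch (the temp file is illustrative):

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "stream.jsonl")
    with open(path, "w") as f:
        f.write("line-1\nline-2\n")

    for line in Streaming(path, max_epoch=2):
        print(line.strip())  # line-1, line-2, line-1, line-2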
 
code/xtuner/_lite/datasets/utils/__init__.py DELETED
@@ -1,12 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from .convert import OPENAI_CONVERT_MAP
- from .load import DATASET_CLS_MAP, load_datasets
- from .utils import apply_exif_orientation, move_data_to_device
- 
- __all__ = [
-     "OPENAI_CONVERT_MAP",
-     "DATASET_CLS_MAP",
-     "load_datasets",
-     "apply_exif_orientation",
-     "move_data_to_device",
- ]
 
code/xtuner/_lite/datasets/utils/convert.py DELETED
@@ -1,195 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import re
- 
- from xtuner._lite.chat import ChatMessages
- 
- 
- class XTunerFormat2Openai:
-     @classmethod
-     def source_format(cls):
-         data = {
-             "conversation": [
-                 {"system": "SYSTEM", "input": "INPUT", "output": "OUTPUT"},
-                 {"input": "INPUT", "output": "OUTPUT"},
-             ]
-         }
-         return data
- 
-     @classmethod
-     def target_format(cls):
-         data = {
-             "messages": [
-                 {"role": "system", "content": "SYSTEM"},
-                 {"role": "user", "content": "INPUT"},
-                 {"role": "assistant", "content": "OUTPUT"},
-                 {"role": "user", "content": "INPUT"},
-                 {"role": "assistant", "content": "OUTPUT"},
-             ]
-         }
-         return data
- 
-     @staticmethod
-     def convert(data):
-         ROLE_MAPPING = {"system": "system", "input": "user", "output": "assistant"}
-         messages = []
-         for single_turn_conversation in data["conversation"]:
-             for role, content in single_turn_conversation.items():
-                 messages.append({"role": ROLE_MAPPING[role], "content": content})
-         return ChatMessages.from_dict({"messages": messages})
- 
- 
- class Alpaca2Openai:
-     @classmethod
-     def source_format(cls):
-         data = {
-             "instruction": "INSTRUCTION",
-             "input": "INPUT",
-             "output": "OUTPUT",
-         }
-         return data
- 
-     @classmethod
-     def target_format(cls):
-         data = {
-             "messages": [
-                 {"role": "user", "content": "INSTRUCTION\nINPUT"},
-                 {"role": "assistant", "content": "OUTPUT"},
-             ]
-         }
-         return data
- 
-     @staticmethod
-     def convert(data):
-         if data.get("output") == "<nooutput>":
-             return ChatMessages.from_dict({"messages": []})
-         else:
-             return ChatMessages.from_dict(
-                 {
-                     "messages": [
-                         {
-                             "role": "user",
-                             "content": f"{data['instruction']}\n{data['input']}",
-                         },
-                         {"role": "assistant", "content": f"{data['output']}"},
-                     ]
-                 }
-             )
- 
- 
- def llava_to_openai(data):
-     image_token = "<image>"
-     conversations = data["conversations"]
-     messages = []
- 
-     if "image" in data:
-         image_urls = data["image"]
-         if isinstance(image_urls, str):
-             image_urls = [image_urls]
-     else:
-         image_urls = None
- 
-     while conversations and conversations[0]["from"] == "gpt":
-         # Skip the first one if it is from gpt
-         conversations = conversations[1:]
- 
-     image_id = 0
-     for convs in conversations:
-         if convs["from"] == "human":
-             pattern = f"({image_token})"
-             chunks = re.split(pattern, convs["value"])
- 
-             text_content = []
-             img_content = []
- 
-             for chunk in chunks:
-                 if chunk == image_token:
-                     url = image_urls[image_id]
-                     if not isinstance(url, str):
-                         raise TypeError(data)
-                     item = dict(type="image_url", image_url=url)
-                     img_content.append(item)
-                     image_id += 1
-                 elif len(chunk.strip()):
-                     item = dict(type="text", text=chunk.strip())
-                     text_content.append(item)
- 
-             msg = {"role": "user", "content": img_content + text_content}
-             messages.append(msg)
- 
-         elif convs["from"] == "gpt":
-             msg = {"role": "assistant", "content": convs["value"]}
-             messages.append(msg)
-         else:
-             raise NotImplementedError
- 
-     return ChatMessages.from_dict({"messages": messages})
- 
- 
- def llava_to_openai_interleave(data):
-     image_token = "<image>"
-     conversations = data["conversations"]
-     messages = []
- 
-     if "image" in data:
-         image_urls = data["image"]
-         if isinstance(image_urls, str):
-             image_urls = [image_urls]
-     else:
-         image_urls = None
- 
-     while conversations and conversations[0]["from"] == "gpt":
-         # Skip the first one if it is from gpt
-         conversations = conversations[1:]
- 
-     image_id = 0
-     for convs in conversations:
-         if convs["from"] == "human":
-             pattern = f"({image_token})"
-             chunks = re.split(pattern, convs["value"])
- 
-             content = []
- 
-             for chunk in chunks:
-                 if chunk == image_token:
-                     url = image_urls[image_id]
-                     if not isinstance(url, str):
-                         raise TypeError(data)
-                     item = dict(type="image_url", image_url=url)
-                     content.append(item)
-                     image_id += 1
-                 elif len(chunk.strip()):
-                     item = dict(type="text", text=chunk.strip())
-                     content.append(item)
- 
-             msg = {"role": "user", "content": content}
-             messages.append(msg)
- 
-         elif convs["from"] == "gpt":
-             msg = {"role": "assistant", "content": convs["value"]}
-             messages.append(msg)
-         else:
-             raise NotImplementedError
- 
-     return ChatMessages.from_dict({"messages": messages})
- 
- 
- def official_openai(data):
-     if "messages" in data:
-         return ChatMessages.from_dict(data)
-     elif "message_data" in data:
-         return ChatMessages.from_dict({"messages": data["message_data"]})
-     elif "dialogs" in data:
-         return ChatMessages.from_dict({"messages": data["dialogs"]})
-     else:
-         return ChatMessages.from_dict({"messages": data})
- 
- 
- OPENAI_CONVERT_MAP = {
-     "llava": llava_to_openai,
-     "llava_interleave": llava_to_openai_interleave,
-     "alpaca": Alpaca2Openai.convert,
-     "xtuner": XTunerFormat2Openai.convert,
-     "openai": official_openai,
- }
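
These converters normalize heterogeneous SFT formats into OpenAI-style messages. A sketch for the Alpaca path (the record is illustrative; the exact attribute layout of the returned `ChatMessages` depends on its definition in `xtuner._lite.chat`):

    record = {
        "instruction": "Translate to French:",
        "input": "Good morning",
        "output": "Bonjour",
    }
    msgs = OPENAI_CONVERT_MAP["alpaca"](record)
    # Expected message list:
    #   {"role": "user", "content": "Translate to French:\nGood morning"}
    #   {"role": "assistant", "content": "Bonjour"}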
 
code/xtuner/_lite/datasets/utils/load.py DELETED
@@ -1,286 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import json
- import math
- import os
- import random
- import re
- 
- from torch import distributed as dist
- from tqdm import tqdm
- 
- from xtuner._lite import get_logger
- 
- from ..json import JsonDataset
- from ..jsonl import JsonlDataset
- 
- logger = get_logger()
- 
- DATASET_CLS_MAP = {".jsonl": JsonlDataset, ".json": JsonDataset}
- 
- 
- def load_hf_dataset(
-     path, split="train", sample_ratio=1.0, cache_dir=None, map_fn=None, max_length=None
- ):
-     from datasets import load_dataset
- 
-     dataset = load_dataset(path)[split]
- 
-     if map_fn:
-         dataset = dataset.map(map_fn, num_proc=8)
- 
-     if sample_ratio != 1:
-         ori_samples = len(dataset)
-         target_samples = int(sample_ratio * ori_samples)
-         indices = random.choices([i for i in range(ori_samples)], k=target_samples)
-         dataset = dataset.select(indices)
- 
-     dataset = dataset.to_list()
- 
-     # `max_length` is accepted for API symmetry with the local dataset
-     # classes; length filtering of HF datasets is not implemented here.
- 
-     # if init_fn:
-     #     dataset = init_fn(dataset)
- 
-     # if cache_dir and isinstance(dataset, CacheDataset):
-     #     dataset.cache(cache_dir)
- 
-     return dataset
- 
- 
- def load_from_cache(cache_dir, init_fn):
-     if dist.is_available() and dist.is_initialized():
-         world_size = dist.get_world_size()
-         rank = dist.get_rank()
-     else:
-         world_size = 1
-         rank = 0
- 
-     sub_cache_dirs = []
-     for _path in tqdm(os.listdir(cache_dir)):
-         path = os.path.join(cache_dir, _path)
-         if os.path.isdir(path):
-             sub_cache_dirs.append(path)
- 
-     num_dsets = len(sub_cache_dirs)
-     avg_num = math.ceil(num_dsets / world_size)
-     start = rank * avg_num
-     end = min((rank + 1) * avg_num, num_dsets)
-     desc = f"[Rank {rank}] Loading Cached Dataset"
- 
-     rank_datasets = []
-     for ind in tqdm(range(start, end), desc=desc):
-         dset = init_fn(sub_cache_dirs[ind])
-         rank_datasets.append(dset)
- 
-     if dist.is_available() and dist.is_initialized() and world_size > 1:
-         dist.barrier()
-         buffers = [None] * world_size
-         dist.all_gather_object(buffers, rank_datasets)
-         world_datasets = []
-         for dsets_per_rank in buffers:
-             world_datasets.extend(dsets_per_rank)
- 
-         assert len(world_datasets) == num_dsets
-     else:
-         world_datasets = rank_datasets
- 
-     return world_datasets
- 
- 
- def load_local_datasets(
-     paths,
-     file_types,
-     file_pattern=None,
-     cache_dir=None,
-     sample_ratios=1.0,
-     map_fns=None,
-     max_length=None,
- ):
-     if isinstance(paths, str):
-         paths = [paths]
- 
-     if isinstance(sample_ratios, (tuple, list)):
-         if len(sample_ratios) == 1:
-             sample_ratios = list(sample_ratios) * len(paths)
- 
-         if len(sample_ratios) != len(paths):
-             raise RuntimeError(
-                 f"There are {len(paths)} paths, but only "
-                 f"{len(sample_ratios)} sample ratios were set."
-             )
- 
-     if map_fns is None:
-         map_fns = [None] * len(paths)
- 
-     if isinstance(map_fns, (tuple, list)):
-         if len(map_fns) == 1:
-             map_fns = list(map_fns) * len(paths)
- 
-         if len(map_fns) != len(paths):
-             raise RuntimeError(
-                 f"There are {len(paths)} paths, but only "
-                 f"{len(map_fns)} map fns were set."
-             )
- 
-     files = []
-     file_sample_ratios = []
-     file_map_fns = []
- 
-     for pid, path in enumerate(paths):
-         if os.path.isdir(path):
-             dir_files = []
-             for root, dirs, _files in os.walk(path, followlinks=True):
-                 dirs.sort()
-                 for relative_path in sorted(_files):
-                     suffix = os.path.splitext(relative_path)[-1]
-                     absolute_path = os.path.join(root, relative_path)
-                     if file_pattern is not None:
-                         if bool(re.match(file_pattern, absolute_path)):
-                             dir_files.append(absolute_path)
-                     elif suffix in file_types:
-                         dir_files.append(absolute_path)
- 
-             _num_dir_files = len(dir_files)
-             if _num_dir_files == 0:
-                 raise RuntimeError(
-                     f"There are no files with the suffix {file_types} "
-                     f"in `{path}`."
-                 )
- 
-             logger.info(f"Found {len(dir_files)} files in {path}")
-             files.extend(dir_files)
-             file_sample_ratios.extend([sample_ratios[pid]] * _num_dir_files)
-             file_map_fns.extend([map_fns[pid]] * _num_dir_files)
- 
-         elif os.path.isfile(path):
-             files.append(path)
-             file_sample_ratios.append(sample_ratios[pid])
-             file_map_fns.append(map_fns[pid])
- 
-         else:
-             raise RuntimeError(f"`{path}` not found.")
- 
-     num_files = len(files)
- 
-     datasets = []
-     for i in range(num_files):
-         _path = files[i]
-         _ratio = file_sample_ratios[i]
-         _map_fn = file_map_fns[i]
-         _suffix = os.path.splitext(_path)[-1]
- 
-         dataset_cls = DATASET_CLS_MAP[_suffix]
-         _dataset = dataset_cls(_path, _ratio, _map_fn, cache_dir, max_length)
-         datasets.append(_dataset)
- 
-     return datasets
- 
- 
- def load_datasets(
-     paths,
-     sources="local",
-     sample_ratios=1.0,
-     file_types=DATASET_CLS_MAP.keys(),
-     file_pattern=None,
-     cache_dir=None,
-     map_fns=None,
-     max_length=None,
- ):
-     if isinstance(paths, str):
-         paths = [paths]
- 
-     num_paths = len(paths)
- 
-     if isinstance(sample_ratios, (float, int)):
-         sample_ratios = [sample_ratios] * num_paths
- 
-     if isinstance(sample_ratios, (tuple, list)):
-         if len(sample_ratios) == 1:
-             sample_ratios = list(sample_ratios) * num_paths
- 
-         if len(sample_ratios) != num_paths:
-             raise RuntimeError(
-                 f"There are {num_paths} paths, but only "
-                 f"{len(sample_ratios)} sample ratios were set."
-             )
- 
-     if isinstance(sources, str):
-         sources = [sources]
- 
-     if isinstance(sources, (tuple, list)):
-         if len(sources) == 1:
-             sources = list(sources) * num_paths
- 
-         if len(sources) != num_paths:
-             raise RuntimeError(
-                 f"There are {num_paths} paths, but only "
-                 f"{len(sources)} sources were set."
-             )
- 
-     if not isinstance(map_fns, (tuple, list)):
-         map_fns = [map_fns] * num_paths
- 
-     if isinstance(map_fns, (tuple, list)):
-         if len(map_fns) == 1:
-             map_fns = list(map_fns) * num_paths
- 
-         if len(map_fns) != num_paths:
-             raise RuntimeError(
-                 f"There are {num_paths} paths, but only "
-                 f"{len(map_fns)} map fns were set."
-             )
- 
-     local_inds = [i for i, src in enumerate(sources) if src == "local"]
-     local_paths = [paths[ind] for ind in local_inds]
-     local_map_fns = [map_fns[ind] for ind in local_inds]
-     local_sample_ratios = [sample_ratios[ind] for ind in local_inds]
- 
-     hf_inds = [i for i, src in enumerate(sources) if src == "huggingface"]
-     hf_paths = [paths[ind] for ind in hf_inds]
-     hf_map_fns = [map_fns[ind] for ind in hf_inds]
-     hf_sample_ratios = [sample_ratios[ind] for ind in hf_inds]
- 
-     datasets = []
-     if len(local_inds):
-         local_datasets = load_local_datasets(
-             local_paths,
-             file_types,
-             file_pattern,
-             cache_dir,
-             local_sample_ratios,
-             local_map_fns,
-             max_length,
-         )
-         datasets.extend(local_datasets)
- 
-     if len(hf_inds):
-         cached_infos = {}
-         for i in range(len(hf_inds)):
-             if cache_dir:
-                 digits = len(str(len(hf_inds)))
-                 cache_id = f"cache-hf-{i+1:0{digits}}-of-{len(hf_inds):0{digits}}"
-                 sub_cache_dir = os.path.join(cache_dir, cache_id)
-             else:
-                 sub_cache_dir = None
-             dset = load_hf_dataset(
-                 hf_paths[i],
-                 sample_ratio=hf_sample_ratios[i],
-                 map_fn=hf_map_fns[i],
-                 cache_dir=sub_cache_dir,
-                 max_length=max_length,
-             )
-             datasets.append(dset)
-             if cache_dir:
-                 infos = {
-                     "path": hf_paths[i],
-                     "num_samples": len(dset),
-                     # Per-sample counts exist only if `map_fn` attached them.
-                     "num_tokens": sum(item.get("num_tokens", 0) for item in dset),
-                 }
-                 cached_infos[cache_id] = infos
- 
-         if cache_dir:
-             _path = os.path.join(cache_dir, "hf_infos.json")
-             with open(_path, "w") as f:
-                 json.dump(cached_infos, f)
- 
-     return datasets
- 
- 
- def load_ms_dataset():
-     pass
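
A sketch of driving the deleted `load_datasets` entry point (the paths below are hypothetical placeholders):

    datasets = load_datasets(
        paths=["data/sft_corpus/", "data/alpaca.json"],  # dir of .jsonl + one .json
        sources="local",
        sample_ratios=[1.0, 0.5],  # keep all of the first path, half of the second
        map_fns=None,              # or per-path tokenize/convert callables
        cache_dir="cache/",
    )
    total_samples = sum(len(d) for d in datasets)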
 
code/xtuner/_lite/datasets/utils/utils.py DELETED
@@ -1,66 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from collections.abc import Mapping
- 
- import torch
- from PIL import Image
- 
- _EXIF_ORIENT = 274  # exif 'Orientation' tag
- 
- 
- def apply_exif_orientation(image):
-     """Applies the exif orientation correctly.
- 
-     This code exists per the bug:
-       https://github.com/python-pillow/Pillow/issues/3973
-     with the function `ImageOps.exif_transpose`. The Pillow source raises
-     errors with various methods, especially `tobytes`.
- 
-     Function based on:
-       https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
-       https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
- 
-     Args:
-         image (PIL.Image): a PIL image
- 
-     Returns:
-         (PIL.Image): the PIL image with exif orientation applied, if applicable
-     """
-     if not hasattr(image, "getexif"):
-         return image
- 
-     try:
-         exif = image.getexif()
-     except Exception:  # https://github.com/facebookresearch/detectron2/issues/1885
-         exif = None
- 
-     if exif is None:
-         return image
- 
-     orientation = exif.get(_EXIF_ORIENT)
- 
-     method = {
-         2: Image.FLIP_LEFT_RIGHT,
-         3: Image.ROTATE_180,
-         4: Image.FLIP_TOP_BOTTOM,
-         5: Image.TRANSPOSE,
-         6: Image.ROTATE_270,
-         7: Image.TRANSVERSE,
-         8: Image.ROTATE_90,
-     }.get(orientation)
- 
-     if method is not None:
-         return image.transpose(method)
-     return image
- 
- 
- def move_data_to_device(data, device="cuda"):
-     """Recursively moves every tensor inside `data` (a tensor or a nested
-     list/dictionary of tensors) to the target device."""
-     if isinstance(data, Mapping):
-         # Propagate `device` into the recursion so nested tensors do not
-         # silently fall back to the default.
-         return type(data)({k: move_data_to_device(v, device) for k, v in data.items()})
-     elif isinstance(data, (tuple, list)):
-         return type(data)(move_data_to_device(v, device) for v in data)
-     elif isinstance(data, torch.Tensor):
-         kwargs = {"device": device}
-         return data.to(non_blocking=True, **kwargs)
-     return data
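
A quick sketch of `move_data_to_device` on a nested batch ("cpu" keeps the example runnable without accelerators):

    import torch

    batch = {
        "input_ids": torch.zeros(2, 8, dtype=torch.long),
        "extras": [torch.ones(2)],
    }
    batch = move_data_to_device(batch, device="cpu")  # structure is preserved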
 
code/xtuner/_lite/device.py DELETED
@@ -1,42 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- 
- 
- def get_device():
-     device = None
-     if torch.cuda.is_available():
-         device = "cuda"
-     else:
-         try:
-             import torch_npu  # noqa: F401
- 
-             device = "npu"
-         except ImportError:
-             pass
-         try:
-             import torch_mlu  # noqa: F401
- 
-             device = "mlu"
-         except ImportError:
-             pass
- 
-     if device is None:
-         raise NotImplementedError(
-             "Supports only CUDA, NPU or MLU devices. If your device is one "
-             "of these, please make sure that your environment is "
-             "configured correctly."
-         )
- 
-     return device
- 
- 
- def get_torch_device_module():
-     device = get_device()
-     if device == "cuda":
-         return torch.cuda
-     elif device == "npu":
-         return torch.npu
-     elif device == "mlu":
-         return torch.mlu
-     else:
-         raise NotImplementedError
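
Typical use of the device helpers (note that `get_device` raises `NotImplementedError` on CPU-only machines):

    device = get_device()                      # "cuda", "npu" or "mlu"
    torch_device = get_torch_device_module()   # e.g. torch.cuda
    print(device, torch_device.device_count())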
 
code/xtuner/_lite/modelings/.DS_Store DELETED
Binary file (6.15 kB)
 
code/xtuner/_lite/modelings/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .internlm2 import InternLM2Config, InternLM2ForCausalLM
- from .internlm3 import InternLM3Config, InternLM3ForCausalLM, InternLM3Tokenizer
- from .llava.modeling_llava import LlavaForConditionalGeneration
- from .llava.configuration_llava import EnhancedLlavaConfig
- from .llava.processing_llava import LlavaProcessor
- 
- 
- def register_remote_code():
-     from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-     AutoConfig.register('internlm2', InternLM2Config, exist_ok=True)
-     AutoModelForCausalLM.register(
-         InternLM2Config, InternLM2ForCausalLM, exist_ok=True)
- 
-     AutoConfig.register('internlm3', InternLM3Config, exist_ok=True)
-     AutoModelForCausalLM.register(
-         InternLM3Config, InternLM3ForCausalLM, exist_ok=True)
-     AutoTokenizer.register(
-         InternLM3Config, InternLM3Tokenizer, exist_ok=True)
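
`register_remote_code` wires the bundled InternLM2/InternLM3 classes into the transformers auto-classes so checkpoints resolve locally instead of via `trust_remote_code=True`. A sketch:

    from transformers import AutoConfig

    register_remote_code()
    config = AutoConfig.for_model("internlm2")  # now resolves to InternLM2Config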
 
code/xtuner/_lite/modelings/internlm2/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .configuration_internlm2 import InternLM2Config
- from .modeling_internlm2 import InternLM2ForCausalLM
 
code/xtuner/_lite/modelings/internlm2/configuration_internlm2.py DELETED
@@ -1,175 +0,0 @@
- # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ InternLM2 model configuration"""
- 
- from transformers.configuration_utils import PretrainedConfig
- from transformers.utils import logging
- 
- logger = logging.get_logger(__name__)
- 
- INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
- 
- 
- # Modified from transformers.model.llama.configuration_llama.LlamaConfig
- class InternLM2Config(PretrainedConfig):
-     r"""
-     This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
-     an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
-     configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
- 
-     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-     documentation from [`PretrainedConfig`] for more information.
- 
-     Args:
-         vocab_size (`int`, *optional*, defaults to 32000):
-             Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented
-             by the `inputs_ids` passed when calling [`InternLM2Model`]
-         hidden_size (`int`, *optional*, defaults to 4096):
-             Dimension of the hidden representations.
-         intermediate_size (`int`, *optional*, defaults to 11008):
-             Dimension of the MLP representations.
-         num_hidden_layers (`int`, *optional*, defaults to 32):
-             Number of hidden layers in the Transformer decoder.
-         num_attention_heads (`int`, *optional*, defaults to 32):
-             Number of attention heads for each attention layer in the Transformer decoder.
-         num_key_value_heads (`int`, *optional*):
-             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
-             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
-             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
-             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
-             constructed by meanpooling all the original heads within that group. For more details check out [this
-             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
-             `num_attention_heads`.
-         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
-             The non-linear activation function (function or string) in the decoder.
-         max_position_embeddings (`int`, *optional*, defaults to 2048):
-             The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
-         initializer_range (`float`, *optional*, defaults to 0.02):
-             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-             The epsilon used by the rms normalization layers.
-         use_cache (`bool`, *optional*, defaults to `True`):
-             Whether or not the model should return the last key/values attentions (not used by all models). Only
-             relevant if `config.is_decoder=True`.
-         pad_token_id (`int`, *optional*):
-             Padding token id.
-         bos_token_id (`int`, *optional*, defaults to 1):
-             Beginning of stream token id.
-         eos_token_id (`int`, *optional*, defaults to 2):
-             End of stream token id.
-         pretraining_tp (`int`, *optional*, defaults to 1):
-             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
-             document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
-             to understand more about it. This value is necessary to ensure exact reproducibility
-             of the pretraining results. Please refer to [this
-             issue](https://github.com/pytorch/pytorch/issues/76232).
-         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
-             Whether to tie weight embeddings
-         rope_theta (`float`, *optional*, defaults to 10000.0):
-             The base period of the RoPE embeddings.
-         rope_scaling (`Dict`, *optional*):
-             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
-             strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
-             is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
-             `max_position_embeddings` to the expected new maximum. See the following thread for more information on
-             how these scaling strategies behave:
-             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is
-             an experimental feature, subject to breaking API changes in future versions.
-     """
-     _auto_class = 'AutoConfig'
-     model_type = 'internlm2'
-     keys_to_ignore_at_inference = ['past_key_values']
- 
-     def __init__(  # pylint: disable=W0102
-         self,
-         vocab_size=103168,
-         hidden_size=4096,
-         intermediate_size=11008,
-         num_hidden_layers=32,
-         num_attention_heads=32,
-         num_key_value_heads=None,
-         hidden_act='silu',
-         max_position_embeddings=2048,
-         initializer_range=0.02,
-         rms_norm_eps=1e-6,
-         use_cache=True,
-         pad_token_id=0,
-         bos_token_id=1,
-         eos_token_id=2,
-         pretraining_tp=1,
-         tie_word_embeddings=False,
-         bias=True,
-         rope_theta=10000,
-         rope_scaling=None,
-         attn_implementation=None,
-         **kwargs,
-     ):
-         self.vocab_size = vocab_size
-         self.max_position_embeddings = max_position_embeddings
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.num_hidden_layers = num_hidden_layers
-         self.num_attention_heads = num_attention_heads
-         self.bias = bias
- 
-         if num_key_value_heads is None:
-             num_key_value_heads = num_attention_heads
-         self.num_key_value_heads = num_key_value_heads
- 
-         self.hidden_act = hidden_act
-         self.initializer_range = initializer_range
-         self.rms_norm_eps = rms_norm_eps
-         self.pretraining_tp = pretraining_tp
-         self.use_cache = use_cache
-         self.rope_theta = rope_theta
-         self.rope_scaling = rope_scaling
-         self._rope_scaling_validation()
-         self.attn_implementation = attn_implementation
-         if self.attn_implementation is None:
-             self.attn_implementation = 'eager'
- 
-         super().__init__(
-             pad_token_id=pad_token_id,
-             bos_token_id=bos_token_id,
-             eos_token_id=eos_token_id,
-             tie_word_embeddings=tie_word_embeddings,
-             **kwargs,
-         )
- 
-     def _rope_scaling_validation(self):
-         """Validate the `rope_scaling` configuration."""
-         if self.rope_scaling is None:
-             return
- 
-         if not isinstance(self.rope_scaling,
-                           dict) or len(self.rope_scaling) != 2:
-             raise ValueError(
-                 '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
-                 f'got {self.rope_scaling}')
-         rope_scaling_type = self.rope_scaling.get('type', None)
-         rope_scaling_factor = self.rope_scaling.get('factor', None)
-         if rope_scaling_type is None or rope_scaling_type not in [
-                 'linear', 'dynamic'
-         ]:
-             raise ValueError(
-                 f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
-             )
-         if (rope_scaling_factor is None
-                 or not isinstance(rope_scaling_factor,
-                                   (float, int)) or rope_scaling_factor < 1.0):
-             raise ValueError(
-                 f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
-                 f'of type {type(rope_scaling_factor)}')
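
A sketch of the RoPE-scaling validation in practice (values are illustrative):

    config = InternLM2Config(
        max_position_embeddings=32768,
        rope_scaling={'type': 'dynamic', 'factor': 2.0},  # validated at init
    )
    # InternLM2Config(rope_scaling={'type': 'yarn', 'factor': 2.0}) would raise
    # ValueError from _rope_scaling_validation().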
 
code/xtuner/_lite/modelings/internlm2/modeling_internlm2.py DELETED
@@ -1,1899 +0,0 @@
- # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch InternLM2.5 model."""
- import math
- import queue
- import threading
- from typing import List, Optional, Tuple, Union
- 
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- from einops import rearrange
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from transformers.activations import ACT2FN
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
- from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                            CausalLMOutputWithPast,
-                                            QuestionAnsweringModelOutput,
-                                            SequenceClassifierOutputWithPast,
-                                            TokenClassifierOutput)
- from transformers.modeling_utils import PreTrainedModel
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
- from transformers.utils import (add_start_docstrings,
-                                 add_start_docstrings_to_model_forward,
-                                 is_flash_attn_greater_or_equal_2_10, logging,
-                                 replace_return_docstrings)
- 
- try:
-     from transformers.generation.streamers import BaseStreamer
- except Exception:
-     BaseStreamer = None
- 
- from .configuration_internlm2 import InternLM2Config
- 
- try:
-     from flash_attn import flash_attn_func, flash_attn_varlen_func
-     from flash_attn.bert_padding import (index_first_axis, pad_input,
-                                          unpad_input)
- except ImportError:
-     pass
- 
- logger = logging.get_logger(__name__)
- 
- _CONFIG_FOR_DOC = 'InternLM2Config'
- 
- 
- def _get_unpad_data(attention_mask):
-     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-     max_seqlen_in_batch = seqlens_in_batch.max().item()
-     cu_seqlens = F.pad(
-         torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # pylint: disable=E1102
-     return (
-         indices,
-         cu_seqlens,
-         max_seqlen_in_batch,
-     )
- 
- 
- class InternLM2RMSNorm(nn.Module):
-     """InternLM2RMSNorm is equivalent to T5LayerNorm."""
- 
-     def __init__(self, hidden_size, eps=1e-6):
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(hidden_size))
-         self.variance_epsilon = eps
- 
-     def forward(self, hidden_states):
-         input_dtype = hidden_states.dtype
-         hidden_states = hidden_states.to(torch.float32)
-         variance = hidden_states.pow(2).mean(-1, keepdim=True)
-         hidden_states = hidden_states * torch.rsqrt(variance +
-                                                     self.variance_epsilon)
-         return self.weight * hidden_states.to(input_dtype)
- 
- 
- ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm)
- 
- 
- class InternLM2RotaryEmbedding(nn.Module):
-     """Rotary Position Embedding for the InternLM2 model. Credits to the Reddit user /u/lucidrains."""
- 
-     def __init__(self,
-                  dim,
-                  max_position_embeddings=2048,
-                  base=10000,
-                  device=None,
-                  scaling_factor=1.0):
-         super().__init__()
-         self.scaling_factor = scaling_factor
-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-         inv_freq = 1.0 / (
-             self.base
-             **(torch.arange(0, self.dim, 2,
-                             dtype=torch.int64).float().to(device) / self.dim))
-         self.register_buffer('inv_freq', inv_freq, persistent=False)
-         # For BC we register cos and sin cached
-         self.max_seq_len_cached = max_position_embeddings
- 
-     @torch.no_grad()
-     def forward(self, x, position_ids):
-         # x: [bs, num_attention_heads, seq_len, head_size]
-         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
-             position_ids.shape[0], -1, 1)
-         position_ids_expanded = position_ids[:, None, :].float()
-         # Force float32 since bfloat16 loses precision on long contexts
-         # See https://github.com/huggingface/transformers/pull/29285
-         device_type = x.device.type
-         device_type = device_type if isinstance(
-             device_type, str) and device_type != 'mps' else 'cpu'
-         with torch.autocast(device_type=device_type, enabled=False):
-             freqs = (inv_freq_expanded.float()
-                      @ position_ids_expanded.float()).transpose(1, 2)
-             emb = torch.cat((freqs, freqs), dim=-1)
-             cos = emb.cos()
-             sin = emb.sin()
-         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
- 
- 
- class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
-     """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
- 
-     def forward(self, x, position_ids):
-         # difference to the original RoPE: a scaling factor is applied to the position ids
-         position_ids = position_ids.float() / self.scaling_factor
-         cos, sin = super().forward(x, position_ids)
-         return cos, sin
- 
- 
- class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
-     """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
-     Credits to the Reddit users /u/bloc97 and /u/emozilla"""
- 
-     def forward(self, x, position_ids):
-         # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
-         seq_len = torch.max(position_ids) + 1
-         if seq_len > self.max_position_embeddings:
-             base = self.base * ((self.scaling_factor * seq_len /
-                                  self.max_position_embeddings) -
-                                 (self.scaling_factor - 1))**(
-                                     self.dim / (self.dim - 2))
-             inv_freq = 1.0 / (
-                 base
-                 **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(
-                     x.device) / self.dim))
-             self.register_buffer(
-                 'inv_freq', inv_freq,
-                 persistent=False)  # TODO joao: this may break with compilation
- 
-         cos, sin = super().forward(x, position_ids)
-         return cos, sin
- 
- 
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., :x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2:]
-     return torch.cat((-x2, x1), dim=-1)
- 
- 
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):  # pylint: disable=unused-argument
-     """Applies Rotary Position Embedding to the query and key tensors.
- 
-     Args:
-         q (`torch.Tensor`): The query tensor.
-         k (`torch.Tensor`): The key tensor.
-         cos (`torch.Tensor`): The cosine part of the rotary embedding.
-         sin (`torch.Tensor`): The sine part of the rotary embedding.
-         position_ids (`torch.Tensor`, *optional*):
-             Deprecated and unused.
-         unsqueeze_dim (`int`, *optional*, defaults to 1):
-             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- 
-     Returns:
-         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-     """
-     cos = cos.unsqueeze(unsqueeze_dim)
-     sin = sin.unsqueeze(unsqueeze_dim)
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
- 
- 
- class InternLM2MLP(nn.Module):
-     """MLP for InternLM2 model."""
- 
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.hidden_size = config.hidden_size
-         self.intermediate_size = config.intermediate_size
-         self.w1 = nn.Linear(
-             self.hidden_size, self.intermediate_size, bias=False)
-         self.w3 = nn.Linear(
-             self.hidden_size, self.intermediate_size, bias=False)
-         self.w2 = nn.Linear(
-             self.intermediate_size, self.hidden_size, bias=False)
-         self.act_fn = ACT2FN[config.hidden_act]
- 
-     def forward(self, x):
-         down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
- 
-         return down_proj
- 
- 
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """
-     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-     """
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :,
-                                   None, :, :].expand(batch,
-                                                      num_key_value_heads,
-                                                      n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
-                                  head_dim)
- 
- 
- class InternLM2Attention(nn.Module):
-     """Multi-headed attention from 'Attention Is All You Need' paper"""
- 
-     def __init__(self,
-                  config: InternLM2Config,
-                  layer_idx: Optional[int] = None):
-         super().__init__()
-         self.config = config
-         self.layer_idx = layer_idx
-         if layer_idx is None:
-             logger.warning_once(
-                 f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
-                 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
-                 'when creating this class.')
- 
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = self.hidden_size // self.num_heads
-         self.num_key_value_heads = config.num_key_value_heads
-         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-         self.max_position_embeddings = config.max_position_embeddings
-         self.rope_theta = config.rope_theta
-         self.is_causal = True
- 
-         if (self.head_dim * self.num_heads) != self.hidden_size:
-             raise ValueError(
-                 f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
-                 f' and `num_heads`: {self.num_heads}).')
- 
-         self.wqkv = nn.Linear(
-             self.hidden_size,
-             (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
-             bias=config.bias,
-         )
-         self.wo = nn.Linear(
-             self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
- 
-         self._init_rope()
- 
-     def _init_rope(self):
-         if self.config.rope_scaling is None:
-             self.rotary_emb = InternLM2RotaryEmbedding(
-                 self.head_dim,
-                 max_position_embeddings=self.max_position_embeddings,
-                 base=self.rope_theta,
-             )
-         else:
-             scaling_type = self.config.rope_scaling['type']
-             scaling_factor = self.config.rope_scaling['factor']
-             if scaling_type == 'linear':
-                 self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
-                     self.head_dim,
-                     max_position_embeddings=self.max_position_embeddings,
-                     scaling_factor=scaling_factor,
-                     base=self.rope_theta,
-                 )
-             elif scaling_type == 'dynamic':
-                 self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
-                     self.head_dim,
-                     max_position_embeddings=self.max_position_embeddings,
-                     scaling_factor=scaling_factor,
-                     base=self.rope_theta,
-                 )
-             else:
-                 raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
- 
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Cache] = None,
-         output_attentions: bool = False,
-         use_cache: bool = False,  # pylint: disable=unused-argument
-         cache_position: Optional[torch.LongTensor] = None,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
-                Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
- 
-         if self.config.pretraining_tp > 1:
-             # split qkv_states by tp size
-             key_value_slicing = (self.num_key_value_heads *
-                                  self.head_dim) // self.config.pretraining_tp
-             qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
-             qkv_states = torch.cat(
-                 [
-                     F.linear(hidden_states, qkv_slice)
-                     for qkv_slice in qkv_slices
-                 ],
-                 dim=-1  # pylint: disable=E1102
-             )
-         else:
-             qkv_states = self.wqkv(hidden_states)
- 
-         qkv_states = rearrange(
-             qkv_states,
-             'b q (h gs d) -> b q h gs d',
-             gs=2 + self.num_key_value_groups,
-             d=self.head_dim,
-         )
- 
-         query_states = qkv_states[..., :self.num_key_value_groups, :]
-         query_states = rearrange(query_states,
-                                  'b q h gs d -> b q (h gs) d').transpose(1, 2)
-         key_states = qkv_states[..., -2, :].transpose(1, 2)
-         value_states = qkv_states[..., -1, :].transpose(1, 2)
- 
-         cos, sin = self.rotary_emb(value_states, position_ids)
-         query_states, key_states = apply_rotary_pos_emb(
-             query_states, key_states, cos, sin, position_ids)
352
-
353
- if past_key_value is not None:
354
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
355
- cache_kwargs = {
356
- 'sin': sin,
357
- 'cos': cos,
358
- 'cache_position': cache_position
359
- }
360
- key_states, value_states = past_key_value.update(
361
- key_states, value_states, self.layer_idx, cache_kwargs)
362
-
363
- key_states = repeat_kv(key_states, self.num_key_value_groups)
364
- value_states = repeat_kv(value_states, self.num_key_value_groups)
365
-
366
- attn_weights = torch.matmul(query_states, key_states.transpose(
367
- 2, 3)) / math.sqrt(self.head_dim)
368
-
369
- if attention_mask is not None: # no matter the length, we just slice it
370
- causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
371
- attn_weights = attn_weights + causal_mask
372
-
373
- # upcast attention to fp32
374
- attn_weights = nn.functional.softmax(
375
- attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
376
- attn_output = torch.matmul(attn_weights, value_states)
377
-
378
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
379
- raise ValueError(
380
- f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
381
- f' {attn_output.size()}')
382
-
383
- attn_output = attn_output.transpose(1, 2).contiguous()
384
-
385
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
386
-
387
- if self.config.pretraining_tp > 1:
388
- attn_output = attn_output.split(
389
- self.hidden_size // self.config.pretraining_tp, dim=2)
390
- o_proj_slices = self.wo.weight.split(
391
- self.hidden_size // self.config.pretraining_tp, dim=1)
392
- attn_output = sum([
393
- F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102
394
- for i in range(self.config.pretraining_tp)
395
- ])
396
- else:
397
- attn_output = self.wo(attn_output)
398
-
399
- if not output_attentions:
400
- attn_weights = None
401
-
402
- return attn_output, attn_weights, past_key_value
403
-
404
-
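# How the packed wqkv projection above is unpacked: heads are stored as
# num_key_value_heads groups, each holding num_key_value_groups query heads
# followed by one key head and one value head. A standalone sketch (assumes
# einops is installed; sizes are illustrative):
import torch
from einops import rearrange

bsz, seq = 2, 5
num_heads, num_kv_heads, head_dim = 8, 2, 16
groups = num_heads // num_kv_heads   # query heads per kv head
packed = torch.randn(bsz, seq, (num_heads + 2 * num_kv_heads) * head_dim)
qkv = rearrange(packed, 'b q (h gs d) -> b q h gs d', gs=2 + groups, d=head_dim)
q = rearrange(qkv[..., :groups, :], 'b q h gs d -> b q (h gs) d')
k, v = qkv[..., -2, :], qkv[..., -1, :]
assert q.shape == (bsz, seq, num_heads, head_dim)
assert k.shape == v.shape == (bsz, seq, num_kv_heads, head_dim)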
405
- class InternLM2FlashAttention2(InternLM2Attention):
406
- """
407
- InternLM2 flash attention module. This module inherits from `InternLM2Attention`, as the weights of the module stay
408
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
409
- flash attention and deal with padding tokens in case the input contains any of them.
410
- """
411
-
412
- def __init__(self, *args, **kwargs):
413
- super().__init__(*args, **kwargs)
414
-
415
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
416
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
417
- # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
418
- # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
419
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
420
- # produces a wrong mask (top-left).
421
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
422
- )
423
-
424
- def forward(
425
- self,
426
- hidden_states: torch.Tensor,
427
- attention_mask: Optional[torch.LongTensor] = None,
428
- position_ids: Optional[torch.LongTensor] = None,
429
- past_key_value: Optional[Cache] = None,
430
- output_attentions: bool = False,
431
- use_cache: bool = False,
432
- cache_position: Optional[torch.LongTensor] = None,
433
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
434
- Optional[Tuple[torch.Tensor]]]:
435
- if isinstance(past_key_value, StaticCache):
436
- raise ValueError(
437
- '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` '
438
- 'make sure to use `sdpa` in the mean time, and open an issue at '
439
- 'https://github.com/huggingface/transformers')
440
-
441
- output_attentions = False
442
-
443
- bsz, q_len, _ = hidden_states.size()
444
-
445
- qkv_states = self.wqkv(hidden_states)
446
-
447
- qkv_states = rearrange(
448
- qkv_states,
449
- 'b q (h gs d) -> b q h gs d',
450
- gs=2 + self.num_key_value_groups,
451
- d=self.head_dim,
452
- )
453
-
454
- query_states = qkv_states[..., :self.num_key_value_groups, :]
455
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
456
- key_states = qkv_states[..., -2, :]
457
- value_states = qkv_states[..., -1, :]
458
-
459
- query_states = query_states.transpose(1, 2)
460
- key_states = key_states.transpose(1, 2)
461
- value_states = value_states.transpose(1, 2)
462
-
463
- cos, sin = self.rotary_emb(value_states, position_ids)
464
- query_states, key_states = apply_rotary_pos_emb(
465
- query_states, key_states, cos, sin)
466
-
467
- if past_key_value is not None:
468
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
469
- cache_kwargs = {
470
- 'sin': sin,
471
- 'cos': cos,
472
- 'cache_position': cache_position
473
- }
474
- key_states, value_states = past_key_value.update(
475
- key_states, value_states, self.layer_idx, cache_kwargs)
476
-
477
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout
478
- # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
479
- # to be able to avoid many of these transpose/reshape/view.
480
- query_states = query_states.transpose(1, 2)
481
- key_states = key_states.transpose(1, 2)
482
- value_states = value_states.transpose(1, 2)
483
-
484
- # dropout_rate = self.attention_dropout if self.training else 0.0
485
- dropout_rate = 0.0
486
-
487
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
488
- # therefore the input hidden states gets silently casted in float32. Hence, we need
489
- # cast them back in the correct dtype just to be sure everything works as expected.
490
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
491
- # in fp32. (InternLM2RMSNorm handles it correctly)
492
-
493
- input_dtype = query_states.dtype
494
- if input_dtype == torch.float32:
495
- if torch.is_autocast_enabled():
496
- target_dtype = torch.get_autocast_gpu_dtype()
497
- # Handle the case where the model is quantized
498
- elif hasattr(self.config, '_pre_quantization_dtype'):
499
- target_dtype = self.config._pre_quantization_dtype
500
- else:
501
- target_dtype = self.wqkv.weight.dtype
502
-
503
- logger.warning_once(
504
- f'The input hidden states seems to be silently casted in float32, this might be related to'
505
- f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
506
- f' {target_dtype}.')
507
-
508
- query_states = query_states.to(target_dtype)
509
- key_states = key_states.to(target_dtype)
510
- value_states = value_states.to(target_dtype)
511
-
512
- attn_output = self._flash_attention_forward(
513
- query_states,
514
- key_states,
515
- value_states,
516
- attention_mask,
517
- q_len,
518
- dropout=dropout_rate)
519
-
520
- attn_output = attn_output.reshape(bsz, q_len,
521
- self.hidden_size).contiguous()
522
- attn_output = self.wo(attn_output)
523
-
524
- if not output_attentions:
525
- attn_weights = None
526
-
527
- return attn_output, attn_weights, past_key_value # pylint: disable=E0606
528
-
529
- def _flash_attention_forward(self,
530
- query_states,
531
- key_states,
532
- value_states,
533
- attention_mask,
534
- query_length,
535
- dropout=0.0,
536
- softmax_scale=None):
537
- """
538
- Calls the forward method of Flash Attention: if the input hidden states contain at least one padding token,
539
- we first unpad the input, then compute the attention scores and pad the final attention scores.
540
- Args:
541
- query_states (`torch.Tensor`):
542
- Input query states to be passed to Flash Attention API
543
- key_states (`torch.Tensor`):
544
- Input key states to be passed to Flash Attention API
545
- value_states (`torch.Tensor`):
546
- Input value states to be passed to Flash Attention API
547
- attention_mask (`torch.Tensor`):
548
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
549
- position of padding tokens and 1 for the position of non-padding tokens.
550
- dropout (`float`):
551
- Attention dropout
552
- softmax_scale (`float`, *optional*):
553
- The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
554
- """
555
- if not self._flash_attn_uses_top_left_mask:
556
- causal = self.is_causal
557
- else:
558
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
559
- # For details, please see the comment in InternLM2FlashAttention2 __init__.
560
- causal = self.is_causal and query_length != 1
561
-
562
- # Contains at least one padding token in the sequence
563
- if attention_mask is not None:
564
- batch_size = query_states.shape[0]
565
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
566
- query_states, key_states, value_states, attention_mask,
567
- query_length)
568
-
569
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
570
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
571
-
572
- attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606
573
- query_states,
574
- key_states,
575
- value_states,
576
- cu_seqlens_q=cu_seqlens_q,
577
- cu_seqlens_k=cu_seqlens_k,
578
- max_seqlen_q=max_seqlen_in_batch_q,
579
- max_seqlen_k=max_seqlen_in_batch_k,
580
- dropout_p=dropout,
581
- softmax_scale=softmax_scale,
582
- causal=causal,
583
- )
584
-
585
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
586
- query_length) # pylint: disable=E0606
587
- else:
588
- attn_output = flash_attn_func( # pylint: disable=E0606
589
- query_states,
590
- key_states,
591
- value_states,
592
- dropout,
593
- softmax_scale=softmax_scale,
594
- causal=causal)
595
-
596
- return attn_output
597
-
598
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
599
- query_length):
600
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
601
- attention_mask)
602
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
603
-
604
- key_layer = index_first_axis( # pylint: disable=E0606
605
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
606
- head_dim), indices_k)
607
- value_layer = index_first_axis( # pylint: disable=E0606
608
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
609
- head_dim), indices_k)
610
- if query_length == kv_seq_len:
611
- query_layer = index_first_axis( # pylint: disable=E0606
612
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
613
- head_dim), indices_k)
614
- cu_seqlens_q = cu_seqlens_k
615
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
616
- indices_q = indices_k
617
- elif query_length == 1:
618
- max_seqlen_in_batch_q = 1
619
- cu_seqlens_q = torch.arange(
620
- batch_size + 1, dtype=torch.int32, device=query_layer.device
621
- ) # There is a memcpy here, that is very bad.
622
- indices_q = cu_seqlens_q[:-1]
623
- query_layer = query_layer.squeeze(1)
624
- else:
625
- # The -q_len: slice assumes left padding.
626
- attention_mask = attention_mask[:, -query_length:]
627
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606
628
- query_layer, attention_mask)
629
-
630
- return (
631
- query_layer,
632
- key_layer,
633
- value_layer,
634
- indices_q,
635
- (cu_seqlens_q, cu_seqlens_k),
636
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
637
- )
638
-
639
-
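# The varlen path above needs cumulative sequence lengths rather than a padded
# mask. This is, in essence, what the _get_unpad_data helper derives from a 2D
# padding mask (a sketch under that assumption, not the exact transformers code):
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)              # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = int(seqlens.max())
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens.tolist())   # [0, 3, 5]: start offset of each unpadded sequence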
640
- # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->InternLM2
641
- class InternLM2SdpaAttention(InternLM2Attention):
642
- """
643
- InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
644
- `InternLM2Attention`, as the weights of the module stay untouched. The only changes are on the forward pass
645
- to adapt to SDPA API.
646
- """
647
-
648
- # Adapted from InternLM2Attention.forward
649
- def forward(
650
- self,
651
- hidden_states: torch.Tensor,
652
- attention_mask: Optional[torch.Tensor] = None,
653
- position_ids: Optional[torch.LongTensor] = None,
654
- past_key_value: Optional[Cache] = None,
655
- output_attentions: bool = False,
656
- use_cache: bool = False,
657
- cache_position: Optional[torch.LongTensor] = None,
658
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
659
- Optional[Tuple[torch.Tensor]]]:
660
- if output_attentions:
661
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
662
- # once this is implemented.
663
- logger.warning_once(
664
- 'InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` '
665
- 'does not support `output_attentions=True`. Falling back to the manual attention implementation, '
666
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. '
667
- 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
668
- )
669
- return super().forward(
670
- hidden_states=hidden_states,
671
- attention_mask=attention_mask,
672
- position_ids=position_ids,
673
- past_key_value=past_key_value,
674
- output_attentions=output_attentions,
675
- use_cache=use_cache,
676
- cache_position=cache_position,
677
- )
678
-
679
- bsz, q_len, _ = hidden_states.size()
680
-
681
- qkv_states = self.wqkv(hidden_states)
682
-
683
- qkv_states = rearrange(
684
- qkv_states,
685
- 'b q (h gs d) -> b q h gs d',
686
- gs=2 + self.num_key_value_groups,
687
- d=self.head_dim,
688
- )
689
-
690
- query_states = qkv_states[..., :self.num_key_value_groups, :]
691
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
692
- key_states = qkv_states[..., -2, :]
693
- value_states = qkv_states[..., -1, :]
694
-
695
- query_states = query_states.transpose(1, 2)
696
- key_states = key_states.transpose(1, 2)
697
- value_states = value_states.transpose(1, 2)
698
-
699
- cos, sin = self.rotary_emb(value_states, position_ids)
700
- query_states, key_states = apply_rotary_pos_emb(
701
- query_states, key_states, cos, sin)
702
-
703
- if past_key_value is not None:
704
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
705
- cache_kwargs = {
706
- 'sin': sin,
707
- 'cos': cos,
708
- 'cache_position': cache_position
709
- }
710
- key_states, value_states = past_key_value.update(
711
- key_states, value_states, self.layer_idx, cache_kwargs)
712
-
713
- key_states = repeat_kv(key_states, self.num_key_value_groups)
714
- value_states = repeat_kv(value_states, self.num_key_value_groups)
715
-
716
- causal_mask = attention_mask
717
- if attention_mask is not None:
718
- causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
719
-
720
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
721
- # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
722
- if query_states.device.type == 'cuda' and causal_mask is not None:
723
- query_states = query_states.contiguous()
724
- key_states = key_states.contiguous()
725
- value_states = value_states.contiguous()
726
-
727
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
728
- # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
729
- # options. An inline conditional prevents dynamic shapes from compiling.
730
- is_causal = bool(causal_mask is None and q_len > 1)
731
-
732
- attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
733
- query_states,
734
- key_states,
735
- value_states,
736
- attn_mask=causal_mask,
737
- dropout_p=0.0,
738
- is_causal=is_causal,
739
- )
740
-
741
- attn_output = attn_output.transpose(1, 2).contiguous()
742
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
743
-
744
- attn_output = self.wo(attn_output)
745
-
746
- return attn_output, None, past_key_value
747
-
748
-
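# The is_causal dispatch above can be verified directly: for a fully causal
# prefill with no padding, the boolean flag and an explicit lower-triangular
# mask give the same SDPA result (small illustrative sketch):
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
tril = torch.tril(torch.ones(4, 4, dtype=torch.bool))
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=tril)
assert torch.allclose(out_flag, out_mask, atol=1e-6)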
749
- INTERNLM2_ATTENTION_CLASSES = {
750
- 'eager': InternLM2Attention,
751
- 'flash_attention_2': InternLM2FlashAttention2,
752
- 'sdpa': InternLM2SdpaAttention,
753
- }
754
-
755
-
756
- # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2
757
- class InternLM2DecoderLayer(nn.Module):
758
- """InternLM2 Decoder Layer. This module is a single layer of the InternLM2 model."""
759
-
760
- def __init__(self, config: InternLM2Config, layer_idx: int):
761
- super().__init__()
762
- self.hidden_size = config.hidden_size
763
- self.layer_idx = layer_idx
764
-
765
- self.attention = INTERNLM2_ATTENTION_CLASSES[
766
- config.attn_implementation](
767
- config=config, layer_idx=layer_idx)
768
-
769
- self.feed_forward = InternLM2MLP(config)
770
- self.attention_norm = InternLM2RMSNorm(
771
- config.hidden_size, eps=config.rms_norm_eps)
772
- self.ffn_norm = InternLM2RMSNorm(
773
- config.hidden_size, eps=config.rms_norm_eps)
774
-
775
- def forward(
776
- self,
777
- hidden_states: torch.Tensor,
778
- attention_mask: Optional[torch.Tensor] = None,
779
- position_ids: Optional[torch.LongTensor] = None,
780
- past_key_value: Optional[Cache] = None,
781
- output_attentions: Optional[bool] = False,
782
- use_cache: Optional[bool] = False,
783
- cache_position: Optional[torch.LongTensor] = None,
784
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
785
- torch.FloatTensor]]]:
786
- """
787
- Args:
788
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
789
- attention_mask (`torch.FloatTensor`, *optional*):
790
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
791
- query_sequence_length, key_sequence_length)` if default attention is used.
792
- output_attentions (`bool`, *optional*):
793
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
794
- returned tensors for more detail.
795
- use_cache (`bool`, *optional*):
796
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
797
- (see `past_key_values`).
798
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
799
- """
800
- residual = hidden_states
801
-
802
- hidden_states = self.attention_norm(hidden_states)
803
-
804
- # Self Attention
805
- hidden_states, self_attn_weights, present_key_value = self.attention(
806
- hidden_states=hidden_states,
807
- attention_mask=attention_mask,
808
- position_ids=position_ids,
809
- past_key_value=past_key_value,
810
- output_attentions=output_attentions,
811
- use_cache=use_cache,
812
- cache_position=cache_position,
813
- )
814
- hidden_states = residual + hidden_states
815
-
816
- # Fully Connected
817
- residual = hidden_states
818
- hidden_states = self.ffn_norm(hidden_states)
819
- hidden_states = self.feed_forward(hidden_states)
820
- hidden_states = residual + hidden_states
821
-
822
- outputs = (hidden_states, )
823
-
824
- if output_attentions:
825
- outputs += (self_attn_weights, )
826
-
827
- if use_cache:
828
- outputs += (present_key_value, )
829
-
830
- return outputs
831
-
832
-
833
- InternLM2_START_DOCSTRING = r"""
834
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
835
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
836
- etc.)
837
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
838
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
839
- and behavior.
840
- Parameters:
841
- config ([`InternLM2Config`]):
842
- Model configuration class with all the parameters of the model. Initializing with a config file does not
843
- load the weights associated with the model, only the configuration. Check out the
844
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
845
- """
846
-
847
-
848
- # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
849
- @add_start_docstrings(
850
- 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
851
- InternLM2_START_DOCSTRING,
852
- )
853
- class InternLM2PreTrainedModel(PreTrainedModel):
854
- """
855
- InternLM2 pretrained model's base class.
856
- """
857
-
858
- config_class = InternLM2Config
859
- base_model_prefix = 'model'
860
- supports_gradient_checkpointing = True
861
- _no_split_modules = ['InternLM2DecoderLayer']
862
- _skip_keys_device_placement = ['past_key_values']
863
- _supports_flash_attn_2 = True
864
- _supports_sdpa = True
865
- _supports_cache_class = True
866
- _supports_quantized_cache = True
867
- _supports_static_cache = True
868
-
869
- def _init_weights(self, module):
870
- std = self.config.initializer_range
871
- if isinstance(module, nn.Linear):
872
- module.weight.data.normal_(mean=0.0, std=std)
873
- if module.bias is not None:
874
- module.bias.data.zero_()
875
- elif isinstance(module, nn.Embedding):
876
- module.weight.data.normal_(mean=0.0, std=std)
877
- if module.padding_idx is not None:
878
- module.weight.data[module.padding_idx].zero_()
879
-
880
-
881
- InternLM2_INPUTS_DOCSTRING = r"""
882
- Args:
883
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
884
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
885
- it.
886
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
887
- [`PreTrainedTokenizer.__call__`] for details.
888
- [What are input IDs?](../glossary#input-ids)
889
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
890
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
891
- - 1 for tokens that are **not masked**,
892
- - 0 for tokens that are **masked**.
893
- [What are attention masks?](../glossary#attention-mask)
894
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
895
- [`PreTrainedTokenizer.__call__`] for details.
896
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
897
- `past_key_values`).
898
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
899
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
900
- information on the default strategy.
901
- - 1 indicates the head is **not masked**,
902
- - 0 indicates the head is **masked**.
903
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
904
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
905
- config.n_positions - 1]`.
906
- [What are position IDs?](../glossary#position-ids)
907
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
908
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
909
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
910
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
911
- Two formats are allowed:
912
- - a [`~cache_utils.Cache`] instance;
913
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
914
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
915
- cache format.
916
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
917
- legacy cache format will be returned.
918
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
919
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
920
- of shape `(batch_size, sequence_length)`.
921
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
922
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
923
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
924
- model's internal embedding lookup matrix.
925
- use_cache (`bool`, *optional*):
926
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
927
- `past_key_values`).
928
- output_attentions (`bool`, *optional*):
929
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
930
- tensors for more detail.
931
- output_hidden_states (`bool`, *optional*):
932
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
933
- more detail.
934
- return_dict (`bool`, *optional*):
935
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
936
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
937
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
938
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
939
- the complete sequence length.
940
- """
941
-
942
-
943
- # Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2
944
- @add_start_docstrings(
945
- 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
946
- InternLM2_START_DOCSTRING,
947
- )
948
- class InternLM2Model(InternLM2PreTrainedModel):
949
- """
950
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is an [`InternLM2DecoderLayer`].
951
- Args:
952
- config: InternLM2Config
953
- """
954
-
955
- _auto_class = 'AutoModel'
956
-
957
- def __init__(self, config: InternLM2Config):
958
- super().__init__(config)
959
- self.padding_idx = config.pad_token_id
960
- self.vocab_size = config.vocab_size
961
- self.config = config
962
-
963
- self.tok_embeddings = nn.Embedding(config.vocab_size,
964
- config.hidden_size,
965
- self.padding_idx)
966
-
967
- self.layers = nn.ModuleList([
968
- InternLM2DecoderLayer(config, layer_idx)
969
- for layer_idx in range(config.num_hidden_layers)
970
- ])
971
- self.norm = InternLM2RMSNorm(
972
- config.hidden_size, eps=config.rms_norm_eps)
973
-
974
- self.gradient_checkpointing = False
975
- # Initialize weights and apply final processing
976
- self.post_init()
977
-
978
- def get_input_embeddings(self):
979
- return self.tok_embeddings
980
-
981
- def set_input_embeddings(self, value):
982
- self.tok_embeddings = value
983
-
984
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
985
- def forward(
986
- self,
987
- input_ids: torch.LongTensor = None,
988
- attention_mask: Optional[torch.Tensor] = None,
989
- position_ids: Optional[torch.LongTensor] = None,
990
- past_key_values: Optional[Union[Cache,
991
- List[torch.FloatTensor]]] = None,
992
- inputs_embeds: Optional[torch.FloatTensor] = None,
993
- use_cache: Optional[bool] = None,
994
- output_attentions: Optional[bool] = None,
995
- output_hidden_states: Optional[bool] = None,
996
- return_dict: Optional[bool] = None,
997
- cache_position: Optional[torch.LongTensor] = None,
998
- ) -> Union[Tuple, BaseModelOutputWithPast]:
999
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1000
- output_hidden_states = (
1001
- output_hidden_states if output_hidden_states is not None else
1002
- self.config.output_hidden_states)
1003
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1004
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1005
-
1006
- if (input_ids is None) ^ (inputs_embeds is not None):
1007
- raise ValueError(
1008
- 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
1009
- )
1010
-
1011
- if self.gradient_checkpointing and self.training and use_cache:
1012
- logger.warning_once(
1013
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
1014
- )
1015
- use_cache = False
1016
-
1017
- if inputs_embeds is None:
1018
- inputs_embeds = self.tok_embeddings(input_ids)
1019
-
1020
- return_legacy_cache = False
1021
- if use_cache and not isinstance(
1022
- past_key_values,
1023
- Cache): # kept for BC (non `Cache` `past_key_values` inputs)
1024
- return_legacy_cache = True
1025
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1026
-
1027
- if cache_position is None:
1028
- past_seen_tokens = past_key_values.get_seq_length(
1029
- ) if past_key_values is not None else 0
1030
- cache_position = torch.arange(
1031
- past_seen_tokens,
1032
- past_seen_tokens + inputs_embeds.shape[1],
1033
- device=inputs_embeds.device)
1034
- if position_ids is None:
1035
- position_ids = cache_position.unsqueeze(0)
1036
-
1037
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
1038
- cache_position, past_key_values,
1039
- output_attentions)
1040
-
1041
- # embed positions
1042
- hidden_states = inputs_embeds
1043
-
1044
- # decoder layers
1045
- all_hidden_states = () if output_hidden_states else None
1046
- all_self_attns = () if output_attentions else None
1047
- next_decoder_cache = None
1048
-
1049
- for decoder_layer in self.layers:
1050
- if output_hidden_states:
1051
- all_hidden_states += (hidden_states, )
1052
-
1053
- if self.gradient_checkpointing and self.training:
1054
- layer_outputs = self._gradient_checkpointing_func(
1055
- decoder_layer.__call__,
1056
- hidden_states,
1057
- causal_mask,
1058
- position_ids,
1059
- past_key_values,
1060
- output_attentions,
1061
- use_cache,
1062
- cache_position,
1063
- )
1064
- else:
1065
- layer_outputs = decoder_layer(
1066
- hidden_states,
1067
- attention_mask=causal_mask,
1068
- position_ids=position_ids,
1069
- past_key_value=past_key_values,
1070
- output_attentions=output_attentions,
1071
- use_cache=use_cache,
1072
- cache_position=cache_position,
1073
- )
1074
-
1075
- hidden_states = layer_outputs[0]
1076
-
1077
- if use_cache:
1078
- next_decoder_cache = layer_outputs[
1079
- 2 if output_attentions else 1]
1080
-
1081
- if output_attentions:
1082
- all_self_attns += (layer_outputs[1], )
1083
-
1084
- hidden_states = self.norm(hidden_states)
1085
-
1086
- # add hidden states from the last decoder layer
1087
- if output_hidden_states:
1088
- all_hidden_states += (hidden_states, )
1089
-
1090
- next_cache = next_decoder_cache if use_cache else None
1091
- if return_legacy_cache:
1092
- next_cache = next_cache.to_legacy_cache()
1093
-
1094
- if not return_dict:
1095
- return tuple(
1096
- v for v in
1097
- [hidden_states, next_cache, all_hidden_states, all_self_attns]
1098
- if v is not None)
1099
- return BaseModelOutputWithPast(
1100
- last_hidden_state=hidden_states,
1101
- past_key_values=next_cache,
1102
- hidden_states=all_hidden_states,
1103
- attentions=all_self_attns,
1104
- )
1105
-
1106
- def _update_causal_mask(
1107
- self,
1108
- attention_mask: torch.Tensor,
1109
- input_tensor: torch.Tensor,
1110
- cache_position: torch.Tensor,
1111
- past_key_values: Cache,
1112
- output_attentions: bool,
1113
- ):
1114
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length
1115
- # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at
1116
- # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is
1117
- # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`.
1118
- # See more context in https://github.com/huggingface/transformers/pull/29114
1119
-
1120
- if self.config.attn_implementation == 'flash_attention_2':
1121
- if attention_mask is not None and 0.0 in attention_mask:
1122
- return attention_mask
1123
- return None
1124
-
1125
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1126
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1127
- # to infer the attention mask.
1128
- past_seen_tokens = past_key_values.get_seq_length(
1129
- ) if past_key_values is not None else 0
1130
- using_static_cache = isinstance(past_key_values, StaticCache)
1131
-
1132
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1133
- if self.config.attn_implementation == 'sdpa' and not using_static_cache and not output_attentions:
1134
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1135
- attention_mask,
1136
- inputs_embeds=input_tensor,
1137
- past_key_values_length=past_seen_tokens,
1138
- is_training=self.training,
1139
- ):
1140
- return None
1141
-
1142
- dtype, device = input_tensor.dtype, input_tensor.device
1143
- min_dtype = torch.finfo(dtype).min
1144
- sequence_length = input_tensor.shape[1]
1145
- if using_static_cache:
1146
- target_length = past_key_values.get_max_length()
1147
- else:
1148
- target_length = (
1149
- attention_mask.shape[-1] if isinstance(
1150
- attention_mask, torch.Tensor) else past_seen_tokens +
1151
- sequence_length + 1)
1152
-
1153
- if attention_mask is not None and attention_mask.dim() == 4:
1154
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
1155
- if attention_mask.max() != 0:
1156
- raise ValueError(
1157
- 'Custom 4D attention mask should be passed in inverted form with max==0`'
1158
- )
1159
- causal_mask = attention_mask
1160
- else:
1161
- causal_mask = torch.full((sequence_length, target_length),
1162
- fill_value=min_dtype,
1163
- dtype=dtype,
1164
- device=device)
1165
- if sequence_length != 1:
1166
- causal_mask = torch.triu(causal_mask, diagonal=1)
1167
- causal_mask *= torch.arange(
1168
- target_length, device=device) > cache_position.reshape(-1, 1)
1169
- causal_mask = causal_mask[None, None, :, :].expand(
1170
- input_tensor.shape[0], 1, -1, -1)
1171
- if attention_mask is not None:
1172
- causal_mask = causal_mask.clone(
1173
- ) # copy to contiguous memory for in-place edit
1174
- mask_length = attention_mask.shape[-1]
1175
- padding_mask = causal_mask[:, :, :, :
1176
- mask_length] + attention_mask[:,
1177
- None,
1178
- None, :]
1179
- padding_mask = padding_mask == 0
1180
- causal_mask[:, :, :, :
1181
- mask_length] = causal_mask[:, :, :, :
1182
- mask_length].masked_fill(
1183
- padding_mask,
1184
- min_dtype)
1185
- if (self.config.attn_implementation == 'sdpa'
1186
- and attention_mask is not None
1187
- and attention_mask.device.type == 'cuda'
1188
- and not output_attentions):
1189
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1190
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1191
- # Details: https://github.com/pytorch/pytorch/issues/110213
1192
- causal_mask = AttentionMaskConverter._unmask_unattended(
1193
- causal_mask, min_dtype) # pylint: disable=E1120
1194
-
1195
- return causal_mask
1196
-
1197
-
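# The mask built above is additive: blocked positions hold the dtype's most
# negative value so they contribute nothing after softmax. A condensed sketch
# of the non-static-cache branch for a fresh 4-token prompt:
import torch

dtype, seq_len = torch.float32, 4
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(seq_len)
causal_mask = torch.full((seq_len, seq_len), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(seq_len) > cache_position.reshape(-1, 1)
print((causal_mask == 0).int())   # lower-triangular ones: row i attends to columns <= i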
1198
- # Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM
1199
- class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1200
- """Causal language model (CLM) for InternLM2."""
1201
-
1202
- _auto_class = 'AutoModelForCausalLM'
1203
- _tied_weights_keys = ['output.weight']
1204
-
1205
- def __init__(self, config):
1206
- super().__init__(config)
1207
- self.model = InternLM2Model(config)
1208
- self.vocab_size = config.vocab_size
1209
- self.output = nn.Linear(
1210
- config.hidden_size, config.vocab_size, bias=False)
1211
-
1212
- # Initialize weights and apply final processing
1213
- self.post_init()
1214
-
1215
- def get_input_embeddings(self):
1216
- return self.model.tok_embeddings
1217
-
1218
- def set_input_embeddings(self, value):
1219
- self.model.tok_embeddings = value
1220
-
1221
- def get_output_embeddings(self):
1222
- return self.output
1223
-
1224
- def set_output_embeddings(self, new_embeddings):
1225
- self.output = new_embeddings
1226
-
1227
- def set_decoder(self, decoder):
1228
- self.model = decoder
1229
-
1230
- def get_decoder(self):
1231
- return self.model
1232
-
1233
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1234
- @replace_return_docstrings(
1235
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1236
- def forward(
1237
- self,
1238
- input_ids: torch.LongTensor = None,
1239
- attention_mask: Optional[torch.Tensor] = None,
1240
- position_ids: Optional[torch.LongTensor] = None,
1241
- past_key_values: Optional[Union[Cache,
1242
- List[torch.FloatTensor]]] = None,
1243
- inputs_embeds: Optional[torch.FloatTensor] = None,
1244
- labels: Optional[torch.LongTensor] = None,
1245
- use_cache: Optional[bool] = None,
1246
- output_attentions: Optional[bool] = None,
1247
- output_hidden_states: Optional[bool] = None,
1248
- return_dict: Optional[bool] = None,
1249
- cache_position: Optional[torch.LongTensor] = None,
1250
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1251
- r"""
1252
- Args:
1253
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1254
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1255
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1256
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1257
- Returns:
1258
- Example:
1259
- ```python
1260
- >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1261
- >>> model = InternLM2ForCausalLM.from_pretrained("internlm/internlm2-7b")
1262
- >>> tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-7b")
1263
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1264
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1265
- >>> # Generate
1266
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1267
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1268
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1269
- ```"""
1270
-
1271
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1272
- output_hidden_states = (
1273
- output_hidden_states if output_hidden_states is not None else
1274
- self.config.output_hidden_states)
1275
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1276
-
1277
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1278
- outputs = self.model(
1279
- input_ids=input_ids,
1280
- attention_mask=attention_mask,
1281
- position_ids=position_ids,
1282
- past_key_values=past_key_values,
1283
- inputs_embeds=inputs_embeds,
1284
- use_cache=use_cache,
1285
- output_attentions=output_attentions,
1286
- output_hidden_states=output_hidden_states,
1287
- return_dict=return_dict,
1288
- cache_position=cache_position,
1289
- )
1290
-
1291
- hidden_states = outputs[0]
1292
- if self.config.pretraining_tp > 1:
1293
- output_slices = self.output.weight.split(
1294
- self.vocab_size // self.config.pretraining_tp, dim=0)
1295
- logits = [
1296
- F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable
1297
- for i in range(self.config.pretraining_tp)
1298
- ]
1299
- logits = torch.cat(logits, dim=-1)
1300
- else:
1301
- logits = self.output(hidden_states)
1302
- logits = logits.float()
1303
-
1304
- loss = None
1305
- if labels is not None:
1306
- # Shift so that tokens < n predict n
1307
- shift_logits = logits[..., :-1, :].contiguous()
1308
- shift_labels = labels[..., 1:].contiguous()
1309
- # Flatten the tokens
1310
- loss_fct = CrossEntropyLoss()
1311
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1312
- shift_labels = shift_labels.view(-1)
1313
- # Enable model parallelism
1314
- shift_labels = shift_labels.to(shift_logits.device)
1315
- loss = loss_fct(shift_logits, shift_labels)
1316
-
1317
- if not return_dict:
1318
- output = (logits, ) + outputs[1:]
1319
- return (loss, ) + output if loss is not None else output
1320
-
1321
- return CausalLMOutputWithPast(
1322
- loss=loss,
1323
- logits=logits,
1324
- past_key_values=outputs.past_key_values,
1325
- hidden_states=outputs.hidden_states,
1326
- attentions=outputs.attentions,
1327
- )
1328
-
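# The loss above is the standard next-token objective: the logit at position t
# is scored against the label at position t+1, and -100 labels are ignored
# (CrossEntropyLoss defaults to ignore_index=-100). A toy-shape sketch:
import torch
from torch.nn import CrossEntropyLoss

vocab_size, bsz, seq = 11, 2, 6
logits = torch.randn(bsz, seq, vocab_size)
labels = torch.randint(0, vocab_size, (bsz, seq))
labels[:, :2] = -100   # e.g. mask out prompt tokens
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)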
1329
- def prepare_inputs_for_generation(
1330
- self,
1331
- input_ids,
1332
- past_key_values=None,
1333
- attention_mask=None,
1334
- inputs_embeds=None,
1335
- cache_position=None,
1336
- use_cache=True,
1337
- **kwargs,
1338
- ):
1339
- past_length = 0
1340
- if past_key_values is not None:
1341
- if isinstance(past_key_values, Cache):
1342
- past_length = cache_position[
1343
- 0] if cache_position is not None else past_key_values.get_seq_length(
1344
- )
1345
- max_cache_length = (
1346
- torch.tensor(
1347
- past_key_values.get_max_length(),
1348
- device=input_ids.device)
1349
- if past_key_values.get_max_length() is not None else None)
1350
- cache_length = past_length if max_cache_length is None else torch.min(
1351
- max_cache_length, past_length)
1352
- # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1353
- else:
1354
- cache_length = past_length = past_key_values[0][0].shape[2]
1355
- max_cache_length = None
1356
-
1357
- # Keep only the unprocessed tokens:
1358
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1359
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
1360
- if attention_mask is not None and attention_mask.shape[
1361
- 1] > input_ids.shape[1]:
1362
- input_ids = input_ids[:, -(attention_mask.shape[1] -
1363
- past_length):]
1364
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1365
- # input_ids based on the past_length.
1366
- elif past_length < input_ids.shape[1]:
1367
- input_ids = input_ids[:, past_length:]
1368
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1369
-
1370
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1371
- if (max_cache_length is not None and attention_mask is not None
1372
- and cache_length + input_ids.shape[1] > max_cache_length):
1373
- attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130
1374
-
1375
- position_ids = kwargs.get('position_ids', None)
1376
- if attention_mask is not None and position_ids is None:
1377
- # create position_ids on the fly for batch generation
1378
- position_ids = attention_mask.long().cumsum(-1) - 1
1379
- position_ids.masked_fill_(attention_mask == 0, 1)
1380
- if past_key_values:
1381
- position_ids = position_ids[:, -input_ids.shape[1]:]
1382
-
1383
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1384
- if inputs_embeds is not None and past_key_values is None:
1385
- model_inputs = {'inputs_embeds': inputs_embeds}
1386
- else:
1387
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1388
- # recompiles graphs as the stride of the inputs is a guard.
1389
- # Ref: https://github.com/huggingface/transformers/pull/29114
1390
- # TODO: use `next_tokens` directly instead.
1391
- model_inputs = {'input_ids': input_ids.contiguous()}
1392
-
1393
- input_length = position_ids.shape[
1394
- -1] if position_ids is not None else input_ids.shape[-1]
1395
- if cache_position is None:
1396
- cache_position = torch.arange(
1397
- past_length,
1398
- past_length + input_length,
1399
- device=input_ids.device)
1400
- elif use_cache:
1401
- cache_position = cache_position[-input_length:]
1402
-
1403
- model_inputs.update({
1404
- 'position_ids': position_ids,
1405
- 'cache_position': cache_position,
1406
- 'past_key_values': past_key_values,
1407
- 'use_cache': use_cache,
1408
- 'attention_mask': attention_mask,
1409
- })
1410
- return model_inputs
1411
-
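# The cumsum trick above derives position ids for left-padded batch
# generation: pad slots are excluded from the count and then clamped to a
# harmless value. A small sketch:
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)   # tensor([[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]])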
1412
- @staticmethod
1413
- def _reorder_cache(past_key_values, beam_idx):
1414
- reordered_past = ()
1415
- for layer_past in past_key_values:
1416
- reordered_past += (tuple(
1417
- past_state.index_select(0, beam_idx.to(past_state.device))
1418
- for past_state in layer_past), )
1419
- return reordered_past
1420
-
1421
- def build_inputs(self,
1422
- tokenizer,
1423
- query: str,
1424
- history: List[Tuple[str, str]] = None,
1425
- meta_instruction=''):
1426
- if history is None:
1427
- history = []
1428
- if tokenizer.add_bos_token:
1429
- prompt = ''
1430
- else:
1431
- prompt = tokenizer.bos_token
1432
- if meta_instruction:
1433
- prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1434
- for record in history:
1435
- prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1436
- prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1437
- return tokenizer([prompt], return_tensors='pt')
1438
-
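# build_inputs above assembles a ChatML-style prompt. For one history turn and
# a new query, the resulting string looks like this (pure-string sketch; a
# leading BOS is prepended only when the tokenizer does not add it itself):
history = [('Hi', 'Hello! How can I help?')]
query = 'What is RoPE?'
meta_instruction = 'You are an AI assistant.'
prompt = f'<|im_start|>system\n{meta_instruction}<|im_end|>\n'
for user_turn, assistant_turn in history:
    prompt += (f'<|im_start|>user\n{user_turn}<|im_end|>\n'
               f'<|im_start|>assistant\n{assistant_turn}<|im_end|>\n')
prompt += f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
print(prompt)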
1439
- @torch.no_grad()
1440
- def chat(
1441
- self,
1442
- tokenizer,
1443
- query: str,
1444
- history: Optional[List[Tuple[str, str]]] = None,
1445
- streamer: Optional[BaseStreamer] = None,
1446
- max_new_tokens: int = 1024,
1447
- do_sample: bool = True,
1448
- temperature: float = 0.8,
1449
- top_p: float = 0.8,
1450
- meta_instruction:
1451
- str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
1452
- '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory '
1453
- '(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
1454
- '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such '
1455
- 'as English and 中文.',
1456
- **kwargs,
1457
- ):
1458
- if history is None:
1459
- history = []
1460
- inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1461
- inputs = {
1462
- k: v.to(self.device)
1463
- for k, v in inputs.items() if torch.is_tensor(v)
1464
- }
1465
- # also add the end-of-assistant token to the eos token ids to avoid unnecessary generation
1466
- eos_token_id = [
1467
- tokenizer.eos_token_id,
1468
- tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]
1469
- ]
1470
- outputs = self.generate(
1471
- **inputs,
1472
- streamer=streamer,
1473
- max_new_tokens=max_new_tokens,
1474
- do_sample=do_sample,
1475
- temperature=temperature,
1476
- top_p=top_p,
1477
- eos_token_id=eos_token_id,
1478
- **kwargs,
1479
- )
1480
- outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
1481
- response = tokenizer.decode(outputs, skip_special_tokens=True)
1482
- response = response.split('<|im_end|>')[0]
1483
- history = history + [(query, response)]
1484
- return response, history
1485
-
1486
- @torch.no_grad()
1487
- def stream_chat(
1488
- self,
1489
- tokenizer,
1490
- query: str,
1491
- history: List[Tuple[str, str]] = None,
1492
- max_new_tokens: int = 1024,
1493
- do_sample: bool = True,
1494
- temperature: float = 0.8,
1495
- top_p: float = 0.8,
1496
- **kwargs,
1497
- ):
1498
- """
- Return a generator of (response, history) tuples, e.g.
- ('Hello, how can I help you', [('Hello', 'Hello, how can I help you')])
- ('Hello, how can I help you?', [('Hello', 'Hello, how can I help you?')])
- """
- if history is None:
- history = []
1506
- if BaseStreamer is None:
1507
- raise ModuleNotFoundError(
1508
- 'The version of `transformers` is too low. Please make sure '
1509
- 'that you have installed `transformers>=4.28.0`.')
1510
-
1511
- response_queue = queue.Queue(maxsize=20)
1512
-
1513
- class ChatStreamer(BaseStreamer):
1514
- """
1515
- Streamer used in generate to print words one by one.
1516
- """
1517
-
1518
- def __init__(self, tokenizer) -> None:
1519
- super().__init__()
1520
- self.tokenizer = tokenizer
1521
- self.queue = response_queue
1522
- self.query = query
1523
- self.history = history
1524
- self.response = ''
1525
- self.cache = []
1526
- self.received_inputs = False
1527
- self.queue.put(
1528
- (self.response, history + [(self.query, self.response)]))
1529
-
1530
- def put(self, value):
1531
- if len(value.shape) > 1 and value.shape[0] > 1:
1532
- raise ValueError('ChatStreamer only supports batch size 1')
1533
- elif len(value.shape) > 1:
1534
- value = value[0]
1535
-
1536
- if not self.received_inputs:
1537
- # The first received value is input_ids, ignore here
1538
- self.received_inputs = True
1539
- return
1540
-
1541
- self.cache.extend(value.tolist())
1542
- token = self.tokenizer.decode(
1543
- self.cache, skip_special_tokens=True)
1544
- if token.strip() != '<|im_end|>':
1545
- self.response = self.response + token
1546
- history = self.history + [(self.query, self.response)]
1547
- self.queue.put((self.response, history))
1548
- self.cache = []
1549
- else:
1550
- self.end()
1551
-
1552
- def end(self):
1553
- self.queue.put(None)
1554
-
1555
- def stream_producer():
1556
- return self.chat(
1557
- tokenizer=tokenizer,
1558
- query=query,
1559
- streamer=ChatStreamer(tokenizer=tokenizer),
1560
- history=history,
1561
- max_new_tokens=max_new_tokens,
1562
- do_sample=do_sample,
1563
- temperature=temperature,
1564
- top_p=top_p,
1565
- **kwargs,
1566
- )
1567
-
1568
- def consumer():
1569
- producer = threading.Thread(target=stream_producer)
1570
- producer.start()
1571
- while True:
1572
- res = response_queue.get()
1573
- if res is None:
1574
- return
1575
- yield res
1576
-
1577
- return consumer()
1578
-
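# stream_chat returns a lazy generator fed by a producer thread; each item is
# the response decoded so far plus the updated history. A hedged usage sketch
# (the repo id is an assumption; any InternLM2 chat checkpoint should work):
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2-chat-7b', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('internlm/internlm2-chat-7b', trust_remote_code=True).eval()
for response, history in model.stream_chat(tokenizer, 'Hello'):
    print(response)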
1579
-
1580
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1581
- @add_start_docstrings(
1582
- """
1583
- The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1584
- [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1585
- (e.g. GPT-2) do.
1586
- Since it does classification on the last token, it needs to know the position of the last token. If a
1587
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1588
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1589
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1590
- each row of the batch).
1591
- """,
1592
- InternLM2_START_DOCSTRING,
1593
- )
1594
- class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1595
- """Sequence Classification Head for InternLM2 Model."""
1596
-
1597
- def __init__(self, config):
1598
- super().__init__(config)
1599
- self.num_labels = config.num_labels
1600
- self.model = InternLM2Model(config)
1601
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1602
-
1603
- # Initialize weights and apply final processing
1604
- self.post_init()
1605
-
1606
- def get_input_embeddings(self):
1607
- return self.model.tok_embeddings
1608
-
1609
- def set_input_embeddings(self, value):
1610
- self.model.tok_embeddings = value
1611
-
1612
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1613
- def forward(
1614
- self,
1615
- input_ids: torch.LongTensor = None,
1616
- attention_mask: Optional[torch.Tensor] = None,
1617
- position_ids: Optional[torch.LongTensor] = None,
1618
- past_key_values: Optional[Union[Cache,
1619
- List[torch.FloatTensor]]] = None,
1620
- inputs_embeds: Optional[torch.FloatTensor] = None,
1621
- labels: Optional[torch.LongTensor] = None,
1622
- use_cache: Optional[bool] = None,
1623
- output_attentions: Optional[bool] = None,
1624
- output_hidden_states: Optional[bool] = None,
1625
- return_dict: Optional[bool] = None,
1626
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1627
- r"""
1628
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1629
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1630
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
1631
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1632
- """
1633
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1634
-
1635
- transformer_outputs = self.model(
1636
- input_ids,
1637
- attention_mask=attention_mask,
1638
- position_ids=position_ids,
1639
- past_key_values=past_key_values,
1640
- inputs_embeds=inputs_embeds,
1641
- use_cache=use_cache,
1642
- output_attentions=output_attentions,
1643
- output_hidden_states=output_hidden_states,
1644
- return_dict=return_dict,
1645
- )
1646
- hidden_states = transformer_outputs[0]
1647
- logits = self.score(hidden_states)
1648
-
1649
- if input_ids is not None:
1650
- batch_size = input_ids.shape[0]
1651
- else:
1652
- batch_size = inputs_embeds.shape[0]
1653
-
1654
- if self.config.pad_token_id is None and batch_size != 1:
1655
- raise ValueError(
1656
- 'Cannot handle batch sizes > 1 if no padding token is defined.'
1657
- )
1658
- if self.config.pad_token_id is None:
1659
- sequence_lengths = -1
1660
- else:
1661
- if input_ids is not None:
1662
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1663
- sequence_lengths = torch.eq(
1664
- input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1665
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1666
- sequence_lengths = sequence_lengths.to(logits.device)
1667
- else:
1668
- sequence_lengths = -1
1669
-
1670
- pooled_logits = logits[torch.arange(batch_size, device=logits.device),
1671
- sequence_lengths]
1672
-
1673
- loss = None
1674
- if labels is not None:
1675
- labels = labels.to(logits.device)
1676
- if self.config.problem_type is None:
1677
- if self.num_labels == 1:
1678
- self.config.problem_type = 'regression'
1679
- elif self.num_labels > 1 and (labels.dtype
1680
- in (torch.long, torch.int)):
1681
- self.config.problem_type = 'single_label_classification'
1682
- else:
1683
- self.config.problem_type = 'multi_label_classification'
1684
-
1685
- if self.config.problem_type == 'regression':
1686
- loss_fct = MSELoss()
1687
- if self.num_labels == 1:
1688
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1689
- else:
1690
- loss = loss_fct(pooled_logits, labels)
1691
- elif self.config.problem_type == 'single_label_classification':
1692
- loss_fct = CrossEntropyLoss()
1693
- loss = loss_fct(
1694
- pooled_logits.view(-1, self.num_labels), labels.view(-1))
1695
- elif self.config.problem_type == 'multi_label_classification':
1696
- loss_fct = BCEWithLogitsLoss()
1697
- loss = loss_fct(pooled_logits, labels)
1698
- if not return_dict:
1699
- output = (pooled_logits, ) + transformer_outputs[1:]
1700
- return ((loss, ) + output) if loss is not None else output
1701
-
1702
- return SequenceClassifierOutputWithPast(
1703
- loss=loss,
1704
- logits=pooled_logits,
1705
- past_key_values=transformer_outputs.past_key_values,
1706
- hidden_states=transformer_outputs.hidden_states,
1707
- attentions=transformer_outputs.attentions,
1708
- )
1709
-
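The pooling trick above deserves a note: `argmax` over the pad-token mask returns 0 when a row contains no padding, so subtracting 1 would point at the wrong token; the modulo wraps that -1 back to the last position. A small sketch with toy ids (assumed `pad_token_id=0`):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 8, 2]])   # second row has no padding

# Index of the first pad token minus one = last real token.
lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# For the unpadded row argmax is 0, so the -1 would point at the wrong spot;
# the modulo wraps -1 around to the final position instead.
lengths = lengths % input_ids.shape[-1]
print(lengths)  # tensor([2, 4])
```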
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2
-@add_start_docstrings(
-    """
-    The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like
-    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
-    """,
-    InternLM2_START_DOCSTRING,
-)
-class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel):
-    """Question Answering model for InternLM2."""
-
-    base_model_prefix = 'transformer'
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.transformer = InternLM2Model(config)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.transformer.tok_embeddings
-
-    def set_input_embeddings(self, value):
-        self.transformer.tok_embeddings = value
-
-    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache,
-                                        List[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.LongTensor] = None,
-        end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
-        r"""
-        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the start of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
-            are not taken into account for computing the loss.
-        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the end of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
-            are not taken into account for computing the loss.
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.transformer(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        sequence_output = outputs[0]
-
-        logits = self.qa_outputs(sequence_output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1).contiguous()
-        end_logits = end_logits.squeeze(-1).contiguous()
-
-        total_loss = None
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, the split adds a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1).to(
-                    start_logits.device)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1).to(end_logits.device)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions = start_positions.clamp(0, ignored_index)
-            end_positions = end_positions.clamp(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
-
-        if not return_dict:
-            output = (start_logits, end_logits) + outputs[2:]
-            return ((total_loss, ) +
-                    output) if total_loss is not None else output
-
-        return QuestionAnsweringModelOutput(
-            loss=total_loss,
-            start_logits=start_logits,
-            end_logits=end_logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-
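The span loss above clamps out-of-range start/end targets onto an extra index equal to the sequence length and then tells `CrossEntropyLoss` to ignore exactly that index, so bad annotations contribute nothing. A sketch with made-up logits:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len = 2, 8
start_logits = torch.randn(batch, seq_len)
end_logits = torch.randn(batch, seq_len)
start_positions = torch.tensor([1, 50])   # 50 is out of range on purpose
end_positions = torch.tensor([3, 60])

# Out-of-range targets are clamped onto an extra "ignored" index, and
# CrossEntropyLoss is told to skip exactly that index.
ignored_index = seq_len
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions)
              + loss_fct(end_logits, end_positions)) / 2
print(total_loss)
```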
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2
-@add_start_docstrings(
-    """
-    The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
-    output) e.g. for Named-Entity-Recognition (NER) tasks.
-    """,
-    InternLM2_START_DOCSTRING,
-)
-class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
-    """Token classification model for InternLM2."""
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.model = InternLM2Model(config)
-        if getattr(config, 'classifier_dropout', None) is not None:
-            classifier_dropout = config.classifier_dropout
-        elif getattr(config, 'hidden_dropout', None) is not None:
-            classifier_dropout = config.hidden_dropout
-        else:
-            classifier_dropout = 0.1
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.score = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.tok_embeddings
-
-    def set_input_embeddings(self, value):
-        self.model.tok_embeddings = value
-
-    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        sequence_output = outputs[0]
-        sequence_output = self.dropout(sequence_output)
-        logits = self.score(sequence_output)
-
-        loss = None
-        if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
-        if not return_dict:
-            output = (logits, ) + outputs[2:]
-            return ((loss, ) + output) if loss is not None else output
-
-        return TokenClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
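The token classification head computes a per-token cross entropy by flattening batch and sequence into one axis; labels set to -100 (the default `ignore_index` of `CrossEntropyLoss`) are skipped, which is how padding is usually masked out. A sketch:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, num_labels = 2, 4, 3
logits = torch.randn(batch, seq_len, num_labels)
labels = torch.tensor([[0, 2, 1, -100],     # -100 marks padding/ignored tokens
                       [1, 0, -100, -100]])

# Flatten (batch, seq) so every token becomes one classification example.
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print(loss)
```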
 
code/xtuner/_lite/modelings/internlm3/__init__.py DELETED
@@ -1,3 +0,0 @@
-from .configuration_internlm3 import InternLM3Config
-from .modeling_internlm3 import InternLM3ForCausalLM
-from .tokenization_internlm3 import InternLM3Tokenizer
 
code/xtuner/_lite/modelings/internlm3/configuration_internlm3.py DELETED
@@ -1,197 +0,0 @@
-# coding=utf-8
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""InternLM3 model configuration"""
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
-from transformers.utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class InternLM3Config(PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of a [`InternLM3Model`]. It is used to instantiate
-    an InternLM3 model according to the specified arguments, defining the model architecture. Instantiating a
-    configuration with the defaults will yield a similar configuration to that of InternLM3-8B.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-
-    Args:
-        vocab_size (`int`, *optional*, defaults to 128512):
-            Vocabulary size of the InternLM3 model. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`InternLM3Model`]
-        hidden_size (`int`, *optional*, defaults to 4096):
-            Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 11008):
-            Dimension of the MLP representations.
-        num_hidden_layers (`int`, *optional*, defaults to 32):
-            Number of hidden layers in the Transformer encoder.
-        num_attention_heads (`int`, *optional*, defaults to 32):
-            Number of attention heads for each attention layer in the Transformer encoder.
-        num_key_value_heads (`int`, *optional*, defaults to 32):
-            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
-            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
-            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
-            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
-            by meanpooling all the original heads within that group. For more details check out [this
-            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
-        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
-            The non-linear activation function (function or string) in the decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 32768):
-            The maximum sequence length that this model might ever be used with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-            The epsilon used by the rms normalization layers.
-        use_cache (`bool`, *optional*, defaults to `True`):
-            Whether or not the model should return the last key/values attentions (not used by all models). Only
-            relevant if `config.is_decoder=True`.
-        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
-            Whether the model's input and output word embeddings should be tied.
-        rope_theta (`float`, *optional*, defaults to 10000.0):
-            The base period of the RoPE embeddings.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
-            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
-            accordingly.
-            Expected contents:
-                `rope_type` (`str`):
-                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
-                    'llama3'], with 'default' being the original RoPE implementation.
-                `factor` (`float`, *optional*):
-                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
-                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-                    original maximum pre-trained length.
-                `original_max_position_embeddings` (`int`, *optional*):
-                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
-                    pretraining.
-                `attention_factor` (`float`, *optional*):
-                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
-                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
-                    `factor` field to infer the suggested value.
-                `beta_fast` (`float`, *optional*):
-                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
-                    ramp function. If unspecified, it defaults to 32.
-                `beta_slow` (`float`, *optional*):
-                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
-                    ramp function. If unspecified, it defaults to 1.
-                `short_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-                    size divided by the number of attention heads divided by 2
-                `long_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
-                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-                    size divided by the number of attention heads divided by 2
-                `low_freq_factor` (`float`, *optional*):
-                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
-                `high_freq_factor` (`float`, *optional*):
-                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
-        qkv_bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in the query, key and value projection layers during self-attention.
-        attention_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for the attention probabilities.
-        bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in o_proj, up_proj, down_proj and gate_proj layers.
-        head_dim (`int`, *optional*):
-            The attention head dimension. If None, it will default to hidden_size // num_heads
-
-    ```python
-    >>> from transformers import InternLM3Model, InternLM3Config
-
-    >>> # Initializing a InternLM3 style configuration
-    >>> configuration = InternLM3Config()
-
-    >>> # Initializing a model from the InternLM3-8B style configuration
-    >>> model = InternLM3Model(configuration)
-
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-
-    model_type = "internlm3"
-    keys_to_ignore_at_inference = ["past_key_values"]
-
-    # Default tensor parallel plan for base model `InternLM3`
-    base_model_tp_plan = {
-        "layers.*.self_attn.q_proj": "colwise",
-        "layers.*.self_attn.k_proj": "colwise",
-        "layers.*.self_attn.v_proj": "colwise",
-        "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.gate_proj": "colwise",
-        "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
-    }
-
-    def __init__(
-        self,
-        vocab_size=128512,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=32,
-        hidden_act="silu",
-        max_position_embeddings=32768,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        tie_word_embeddings=False,
-        rope_theta=10000.0,
-        rope_scaling=None,
-        qkv_bias=False,
-        attention_dropout=0.0,
-        bias=False,
-        head_dim=None,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.qkv_bias = qkv_bias
-        self.attention_dropout = attention_dropout
-        self.bias = bias
-        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
-        # Validate the correctness of rotary position embeddings parameters
-        # BC: if there is a 'type' field, move it to 'rope_type'.
-        if self.rope_scaling is not None and "type" in self.rope_scaling:
-            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        rope_config_validation(self)
-
-        super().__init__(
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
 
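Two details of the deleted config worth noting: `head_dim` falls back to `hidden_size // num_attention_heads`, and a legacy `rope_scaling["type"]` key is copied to `rope_type` before validation. A minimal sketch, assuming the deleted module were still importable as `configuration_internlm3`:

```python
# Hypothetical import path; the file above was deleted from the repo.
from configuration_internlm3 import InternLM3Config

cfg = InternLM3Config(
    hidden_size=2048,
    num_attention_heads=16,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # legacy "type" key
)
print(cfg.head_dim)                   # 128 == hidden_size // num_attention_heads
print(cfg.rope_scaling["rope_type"])  # "dynamic" -- copied from the legacy key
```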
code/xtuner/_lite/modelings/internlm3/modeling_internlm3.py DELETED
@@ -1,825 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/internlm3/modular_internlm3.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_internlm3.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from transformers.utils import logging
-
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
-from transformers.generation import GenerationMixin
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from .configuration_internlm3 import InternLM3Config
-
-
-logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "InternLM3Config"
-
-
-class InternLM3MLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
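`repeat_kv` expands the key/value heads so each group of query heads sees its shared KV head; the expand-then-reshape produces the same layout as `repeat_interleave` without materializing an intermediate copy per head. A quick equivalence check (the function restated inline so the snippet runs standalone):

```python
import torch

def repeat_kv(hidden_states, n_rep):
    # Same logic as above: expand a new repeat axis, then fold it into the head axis.
    b, kv, s, d = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    return hidden_states[:, :, None, :, :].expand(b, kv, n_rep, s, d).reshape(b, kv * n_rep, s, d)

kv_states = torch.randn(2, 4, 5, 8)        # 4 key/value heads
expanded = repeat_kv(kv_states, n_rep=3)   # -> 12 "virtual" heads for 12 query heads
reference = torch.repeat_interleave(kv_states, repeats=3, dim=1)
print(expanded.shape)                      # torch.Size([2, 12, 5, 8])
print(torch.equal(expanded, reference))    # True
```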
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
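For the unmasked case, the eager path above is the plain scaled-dot-product formula, so it should agree numerically with PyTorch's fused kernel (`F.scaled_dot_product_attention`, available since torch 2.0, whose default scale is also `1/sqrt(head_dim)`). A sketch on toy tensors:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
scaling = q.shape[-1] ** -0.5

# Eager path: explicit matmul + softmax, as in eager_attention_forward.
weights = torch.softmax(torch.matmul(q, k.transpose(2, 3)) * scaling, dim=-1)
eager_out = torch.matmul(weights, v)

# Fused path; should agree up to numerical tolerance.
sdpa_out = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # True
```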
-class InternLM3Attention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(self, config: InternLM3Config, layer_idx: int):
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.scaling = self.head_dim**-0.5
-        self.attention_dropout = config.attention_dropout
-        self.is_causal = True
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_value: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, attn_weights = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            **kwargs,
-        )
-
-        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights
-
-
-class InternLM3RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        InternLM3RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
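RMSNorm, as implemented above, rescales each position to unit root-mean-square in float32 and skips LayerNorm's mean subtraction and bias term. A sketch of the same arithmetic:

```python
import torch

x = torch.randn(2, 3, 16)
eps, weight = 1e-6, torch.ones(16)

# Scale by the reciprocal root-mean-square of the features; no mean
# subtraction (unlike LayerNorm) and no bias.
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
out = weight * (x * rms)
print(out.pow(2).mean(-1))  # each value is ~1.0: unit RMS per position
```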
-class InternLM3DecoderLayer(nn.Module):
-    def __init__(self, config: InternLM3Config, layer_idx: int):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.self_attn = InternLM3Attention(config=config, layer_idx=layer_idx)
-        self.mlp = InternLM3MLP(config)
-        self.input_layernorm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-            **kwargs,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
-
-
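The decoder layer follows the pre-norm residual pattern, `x = x + sublayer(norm(x))`, applied once for attention and once for the MLP. A sketch of the skeleton, with LayerNorm and a plain Linear as stand-ins for RMSNorm and the gated MLP:

```python
import torch

def prenorm_block(x, norm, sublayer):
    # Pre-norm residual: normalize first, transform, then add back the input.
    return x + sublayer(norm(x))

x = torch.randn(2, 5, 16)
norm = torch.nn.LayerNorm(16)   # stand-in for RMSNorm in this sketch
mlp = torch.nn.Linear(16, 16)   # stand-in for the gated MLP
print(prenorm_block(x, norm, mlp).shape)  # torch.Size([2, 5, 16])
```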
-class InternLM3RotaryEmbedding(nn.Module):
-    def __init__(self, config: InternLM3Config, device=None):
-        super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
-    @torch.no_grad()
-    def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
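For the 'default' rope type, `inv_freq` follows the standard RoPE schedule `1 / theta**(2i/d)`; the forward pass takes an outer product with the positions and duplicates it to fill the head dimension. A sketch of that computation (assumed `head_dim=8`, `theta=10000`):

```python
import torch

head_dim, theta = 8, 10000.0
# Default RoPE frequencies: inv_freq[i] = 1 / theta**(2i / head_dim).
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

position_ids = torch.arange(6)[None, :]              # (batch=1, seq=6)
freqs = position_ids[..., None].float() * inv_freq   # (1, 6, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)              # (1, 6, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)   # torch.Size([1, 6, 8])
print(cos[0, 0])   # position 0 -> all ones (no rotation)
```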
319
- INTERNLM3_START_DOCSTRING = r"""
320
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
321
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
322
- etc.)
323
-
324
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
325
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
326
- and behavior.
327
-
328
- Parameters:
329
- config ([`InternLM3Config`]):
330
- Model configuration class with all the parameters of the model. Initializing with a config file does not
331
- load the weights associated with the model, only the configuration. Check out the
332
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
333
- """
334
-
335
-
336
- @add_start_docstrings(
337
- "The bare InternLM3 Model outputting raw hidden-states without any specific head on top.",
338
- INTERNLM3_START_DOCSTRING,
339
- )
340
- class InternLM3PreTrainedModel(PreTrainedModel):
341
- config_class = InternLM3Config
342
- base_model_prefix = "model"
343
- supports_gradient_checkpointing = True
344
- _no_split_modules = ["InternLM3DecoderLayer"]
345
- _skip_keys_device_placement = ["past_key_values"]
346
- _supports_flash_attn_2 = True
347
- _supports_sdpa = True
348
- _supports_flex_attn = True
349
- _supports_cache_class = True
350
- _supports_quantized_cache = True
351
- _supports_static_cache = True
352
-
353
- def _init_weights(self, module):
354
- std = self.config.initializer_range
355
- if isinstance(module, nn.Linear):
356
- module.weight.data.normal_(mean=0.0, std=std)
357
- if module.bias is not None:
358
- module.bias.data.zero_()
359
- elif isinstance(module, nn.Embedding):
360
- module.weight.data.normal_(mean=0.0, std=std)
361
- if module.padding_idx is not None:
362
- module.weight.data[module.padding_idx].zero_()
363
-
364
-
365
- INTERNLM3_INPUTS_DOCSTRING = r"""
366
- Args:
367
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
368
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
369
- it.
370
-
371
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
372
- [`PreTrainedTokenizer.__call__`] for details.
373
-
374
- [What are input IDs?](../glossary#input-ids)
375
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
376
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
377
-
378
- - 1 for tokens that are **not masked**,
379
- - 0 for tokens that are **masked**.
380
-
381
- [What are attention masks?](../glossary#attention-mask)
382
-
383
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
384
- [`PreTrainedTokenizer.__call__`] for details.
385
-
386
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
387
- `past_key_values`).
388
-
389
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
390
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
391
- information on the default strategy.
392
-
393
- - 1 indicates the head is **not masked**,
394
- - 0 indicates the head is **masked**.
395
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
396
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
397
- config.n_positions - 1]`.
398
-
399
- [What are position IDs?](../glossary#position-ids)
400
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
401
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
402
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
403
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
404
-
405
- Two formats are allowed:
406
- - a [`~cache_utils.Cache`] instance, see our
407
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
408
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
409
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
410
- cache format.
411
-
412
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
413
- legacy cache format will be returned.
414
-
415
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
416
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
417
- of shape `(batch_size, sequence_length)`.
418
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
419
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
420
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
421
- model's internal embedding lookup matrix.
422
- use_cache (`bool`, *optional*):
423
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
424
- `past_key_values`).
425
- output_attentions (`bool`, *optional*):
426
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
427
- tensors for more detail.
428
- output_hidden_states (`bool`, *optional*):
429
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
430
- more detail.
431
- return_dict (`bool`, *optional*):
432
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
433
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
434
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
435
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
436
- the complete sequence length.
437
- """
438
-
439
-
440
- @add_start_docstrings(
441
- "The bare InternLM3 Model outputting raw hidden-states without any specific head on top.",
442
- INTERNLM3_START_DOCSTRING,
443
- )
444
- class InternLM3Model(InternLM3PreTrainedModel):
445
- """
446
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM3DecoderLayer`]
447
-
448
- Args:
449
- config: InternLM3Config
450
- """
451
-
452
- def __init__(self, config: InternLM3Config):
453
- super().__init__(config)
454
- self.padding_idx = config.pad_token_id
455
- self.vocab_size = config.vocab_size
456
-
457
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
458
- self.layers = nn.ModuleList(
459
- [InternLM3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
460
- )
461
- self.norm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
462
- self.rotary_emb = InternLM3RotaryEmbedding(config=config)
463
- self.gradient_checkpointing = False
464
-
465
- # Initialize weights and apply final processing
466
- self.post_init()
467
-
468
- def get_input_embeddings(self):
469
- return self.embed_tokens
470
-
471
- def set_input_embeddings(self, value):
472
- self.embed_tokens = value
473
-
474
- @add_start_docstrings_to_model_forward(INTERNLM3_INPUTS_DOCSTRING)
475
- def forward(
476
- self,
477
- input_ids: torch.LongTensor = None,
478
- attention_mask: Optional[torch.Tensor] = None,
479
- position_ids: Optional[torch.LongTensor] = None,
480
- past_key_values: Optional[Cache] = None,
481
- inputs_embeds: Optional[torch.FloatTensor] = None,
482
- use_cache: Optional[bool] = None,
483
- output_attentions: Optional[bool] = None,
484
- output_hidden_states: Optional[bool] = None,
485
- return_dict: Optional[bool] = None,
486
- cache_position: Optional[torch.LongTensor] = None,
487
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
488
- ) -> Union[Tuple, BaseModelOutputWithPast]:
489
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
- output_hidden_states = (
491
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
- )
493
- use_cache = use_cache if use_cache is not None else self.config.use_cache
494
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
495
-
496
- if (input_ids is None) ^ (inputs_embeds is not None):
497
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
498
-
499
- if self.gradient_checkpointing and self.training and use_cache:
500
- logger.warning_once(
501
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
502
- )
503
- use_cache = False
504
-
505
- if inputs_embeds is None:
506
- inputs_embeds = self.embed_tokens(input_ids)
507
-
508
- if use_cache and past_key_values is None:
509
- past_key_values = DynamicCache()
510
-
511
- if cache_position is None:
512
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
513
- cache_position = torch.arange(
514
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
515
- )
516
-
517
- if position_ids is None:
518
- position_ids = cache_position.unsqueeze(0)
519
-
520
- causal_mask = self._update_causal_mask(
521
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
522
- )
523
-
524
- hidden_states = inputs_embeds
525
-
526
- # create position embeddings to be shared across the decoder layers
527
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
528
-
529
- # decoder layers
530
- all_hidden_states = () if output_hidden_states else None
531
- all_self_attns = () if output_attentions else None
532
-
533
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
534
- if output_hidden_states:
535
- all_hidden_states += (hidden_states,)
536
-
537
- if self.gradient_checkpointing and self.training:
538
- layer_outputs = self._gradient_checkpointing_func(
539
- decoder_layer.__call__,
540
- hidden_states,
541
- causal_mask,
542
- position_ids,
543
- past_key_values,
544
- output_attentions,
545
- use_cache,
546
- cache_position,
547
- position_embeddings,
548
- )
549
- else:
550
- layer_outputs = decoder_layer(
551
- hidden_states,
552
- attention_mask=causal_mask,
553
- position_ids=position_ids,
554
- past_key_value=past_key_values,
555
- output_attentions=output_attentions,
556
- use_cache=use_cache,
557
- cache_position=cache_position,
558
- position_embeddings=position_embeddings,
559
- **flash_attn_kwargs,
560
- )
561
-
562
- hidden_states = layer_outputs[0]
563
-
564
- if output_attentions:
565
- all_self_attns += (layer_outputs[1],)
566
-
567
- hidden_states = self.norm(hidden_states)
568
-
569
- # add hidden states from the last decoder layer
570
- if output_hidden_states:
571
- all_hidden_states += (hidden_states,)
572
-
573
- output = BaseModelOutputWithPast(
574
- last_hidden_state=hidden_states,
575
- past_key_values=past_key_values if use_cache else None,
576
- hidden_states=all_hidden_states,
577
- attentions=all_self_attns,
578
- )
579
- return output if return_dict else output.to_tuple()
580
-
581
- def _update_causal_mask(
582
- self,
583
- attention_mask: torch.Tensor,
584
- input_tensor: torch.Tensor,
585
- cache_position: torch.Tensor,
586
- past_key_values: Cache,
587
- output_attentions: bool,
588
- ):
589
- if self.config._attn_implementation == "flash_attention_2":
590
- if attention_mask is not None and (attention_mask == 0.0).any():
591
- return attention_mask
592
- return None
593
-
594
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
595
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
596
- # to infer the attention mask.
597
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
598
- using_static_cache = isinstance(past_key_values, StaticCache)
599
-
600
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
601
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
602
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
603
- attention_mask,
604
- inputs_embeds=input_tensor,
605
- past_key_values_length=past_seen_tokens,
606
- is_training=self.training,
607
- ):
608
- return None
609
-
610
- dtype, device = input_tensor.dtype, input_tensor.device
611
- sequence_length = input_tensor.shape[1]
612
- if using_static_cache:
613
-             target_length = past_key_values.get_max_cache_shape()
-         else:
-             target_length = (
-                 attention_mask.shape[-1]
-                 if isinstance(attention_mask, torch.Tensor)
-                 else past_seen_tokens + sequence_length + 1
-             )
- 
-         # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
-         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-             attention_mask,
-             sequence_length=sequence_length,
-             target_length=target_length,
-             dtype=dtype,
-             device=device,
-             cache_position=cache_position,
-             batch_size=input_tensor.shape[0],
-         )
- 
-         if (
-             self.config._attn_implementation == "sdpa"
-             and attention_mask is not None
-             and attention_mask.device.type == "cuda"
-             and not output_attentions
-         ):
-             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-             # Details: https://github.com/pytorch/pytorch/issues/110213
-             min_dtype = torch.finfo(dtype).min
-             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
- 
-         return causal_mask
- 
-     @staticmethod
-     def _prepare_4d_causal_attention_mask_with_cache_position(
-         attention_mask: torch.Tensor,
-         sequence_length: int,
-         target_length: int,
-         dtype: torch.dtype,
-         device: torch.device,
-         cache_position: torch.Tensor,
-         batch_size: int,
-         **kwargs,
-     ):
-         """
-         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-         `(batch_size, key_value_length)`, or, if the input `attention_mask` is already 4D, returns it unchanged.
- 
-         Args:
-             attention_mask (`torch.Tensor`):
-                 A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                 `(batch_size, 1, query_length, key_value_length)`.
-             sequence_length (`int`):
-                 The sequence length being processed.
-             target_length (`int`):
-                 The target length: when generating with a static cache, the mask should be as long as the static
-                 cache, to account for the zero padding (the part of the cache that is not filled yet).
-             dtype (`torch.dtype`):
-                 The dtype to use for the 4D attention mask.
-             device (`torch.device`):
-                 The device to place the 4D attention mask on.
-             cache_position (`torch.Tensor`):
-                 Indices depicting the position of the input sequence tokens in the sequence.
-             batch_size (`int`):
-                 Batch size.
-         """
-         if attention_mask is not None and attention_mask.dim() == 4:
-             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-             causal_mask = attention_mask
-         else:
-             min_dtype = torch.finfo(dtype).min
-             causal_mask = torch.full(
-                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-             )
-             if sequence_length != 1:
-                 causal_mask = torch.triu(causal_mask, diagonal=1)
-             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-             if attention_mask is not None:
-                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                 mask_length = attention_mask.shape[-1]
-                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                 padding_mask = padding_mask == 0
-                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                     padding_mask, min_dtype
-                 )
- 
-         return causal_mask
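The `else` branch above is the whole mask recipe: fill a square additive mask with the dtype minimum, zero the lower triangle, then splice in the 2D padding mask. A minimal, self-contained sketch that replays those steps on toy shapes (it re-implements the logic locally rather than calling the deleted class, so all names below belong to the sketch only):

```python
import torch

# Toy replay of the else-branch: two prompts of length 4, prefill step
# (cache_position = 0..3), target_length == sequence_length.
dtype, device = torch.float32, "cpu"
seq_len = tgt_len = 4
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])  # second prompt is left-padded
cache_position = torch.arange(seq_len)

min_dtype = torch.finfo(dtype).min
mask = torch.full((seq_len, tgt_len), min_dtype, dtype=dtype, device=device)
mask = torch.triu(mask, diagonal=1)                    # causal part: mask j > i
mask *= torch.arange(tgt_len) > cache_position.reshape(-1, 1)
mask = mask[None, None, :, :].expand(2, 1, -1, -1).clone()
pad = (mask + attention_mask[:, None, None, :]) == 0   # allowed slots hitting padding
mask = mask.masked_fill(pad, min_dtype)
print((mask == 0).int()[1, 0])                         # 1 where a query may attend
```

For the left-padded prompt, the first two query rows come out fully masked; that is exactly the situation the SDPA branch above repairs with `AttentionMaskConverter._unmask_unattended`.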
- 
- 
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
- 
- 
- class InternLM3ForCausalLM(InternLM3PreTrainedModel, GenerationMixin):
-     _auto_class = 'AutoModelForCausalLM'
-     _tied_weights_keys = ["lm_head.weight"]
-     _tp_plan = {"lm_head": "colwise_rep"}
- 
-     def __init__(self, config):
-         super().__init__(config)
-         self.model = InternLM3Model(config)
-         self.vocab_size = config.vocab_size
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- 
-         # Initialize weights and apply final processing
-         self.post_init()
- 
-     def get_input_embeddings(self):
-         return self.model.embed_tokens
- 
-     def set_input_embeddings(self, value):
-         self.model.embed_tokens = value
- 
-     def get_output_embeddings(self):
-         return self.lm_head
- 
-     def set_output_embeddings(self, new_embeddings):
-         self.lm_head = new_embeddings
- 
-     def set_decoder(self, decoder):
-         self.model = decoder
- 
-     def get_decoder(self):
-         return self.model
- 
-     @add_start_docstrings_to_model_forward(INTERNLM3_INPUTS_DOCSTRING)
-     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-         num_logits_to_keep: int = 0,
-         **kwargs: Unpack[KwargsForCausalLM],
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         r"""
-         Args:
-             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                 (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- 
-             num_logits_to_keep (`int`, *optional*):
-                 Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
-                 `input_ids` (special case). Only the last token's logits are needed for generation, and computing
-                 them only for that token saves memory, which becomes quite significant for long sequences or large
-                 vocabulary sizes.
- 
-         Returns:
- 
-         Example:
- 
-         ```python
-         >>> from transformers import AutoTokenizer, InternLM3ForCausalLM
- 
-         >>> model = InternLM3ForCausalLM.from_pretrained("meta-internlm3/InternLM3-2-7b-hf")
-         >>> tokenizer = AutoTokenizer.from_pretrained("meta-internlm3/InternLM3-2-7b-hf")
- 
-         >>> prompt = "Hey, are you conscious? Can you talk to me?"
-         >>> inputs = tokenizer(prompt, return_tensors="pt")
- 
-         >>> # Generate
-         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-         ```"""
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- 
-         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-             cache_position=cache_position,
-             **kwargs,
-         )
- 
-         hidden_states = outputs[0]
-         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
- 
-         loss = None
-         if labels is not None:
-             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
- 
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
- 
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
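The `num_logits_to_keep` trick in `forward` is just a slice before the LM head; a standalone sketch of the saving with toy sizes (the dimensions are illustrative, not taken from the real config):

```python
import torch
import torch.nn as nn

# Illustrative sizes only: the saving scales with vocab_size, which is
# far larger for real checkpoints.
hidden_size, vocab_size, seq_len = 64, 1000, 512
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(1, seq_len, hidden_size)

num_logits_to_keep = 1  # generation only needs the last position
logits_all = lm_head(hidden_states)                               # (1, 512, 1000)
logits_last = lm_head(hidden_states[:, -num_logits_to_keep:, :])  # (1, 1, 1000)
assert torch.allclose(logits_all[:, -1:, :], logits_last)
```

Note the "special case" in the docstring falls out of Python slicing: with `num_logits_to_keep=0` the slice `[-0:]` covers the whole sequence.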
code/xtuner/_lite/modelings/internlm3/tokenization_internlm3.py DELETED
@@ -1,295 +0,0 @@
- import os
- from shutil import copyfile
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
- 
- import sentencepiece as spm
- from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
- from transformers.utils import logging
- 
- if TYPE_CHECKING:
-     from transformers.tokenization_utils_base import TextInput
- 
- logger = logging.get_logger(__name__)
- 
- VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
- 
- SPIECE_UNDERLINE = "▁"
- 
- 
- class InternLM3Tokenizer(PreTrainedTokenizer):
-     """
-     Construct an InternLM3 tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset, as
-     there is no padding token in the original model.
- 
-     Args:
-         vocab_file (`str`):
-             Path to the vocabulary file.
-         unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
-             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
-             this token instead.
-         bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
-             The beginning-of-sequence token that was used during pretraining. Can be used as a sequence classifier
-             token.
-         eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
-             The end-of-sequence token.
-         pad_token (`str` or `tokenizers.AddedToken`, *optional*):
-             A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored
-             by attention mechanisms or loss computation.
-         sp_model_kwargs (`Dict[str, Any]`, *optional*):
-             Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
-             SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other
-             things, to set:
- 
-             - `enable_sampling`: Enable subword regularization.
-             - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- 
-               - `nbest_size = {0,1}`: No sampling is performed.
-               - `nbest_size > 1`: samples from the nbest_size results.
-               - `nbest_size < 0`: assumes that nbest_size is infinite and samples from all hypotheses (lattice)
-                 using the forward-filtering-and-backward-sampling algorithm.
- 
-             - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
-               BPE-dropout.
- 
-         add_bos_token (`bool`, *optional*, defaults to `True`):
-             Whether or not to add a `bos_token` at the start of sequences.
-         add_eos_token (`bool`, *optional*, defaults to `False`):
-             Whether or not to add an `eos_token` at the end of sequences.
-         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
-             Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
-             extra spaces.
-         use_default_system_prompt (`bool`, *optional*, defaults to `False`):
-             Whether or not the default system prompt for InternLM3 should be used.
-         spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
-             Whether or not to add spaces between special tokens.
-         spaces_for_interleaved_special_tokens (`bool`, *optional*, defaults to `False`):
-             Whether or not to add spaces between special tokens that are interleaved with normal tokens.
-         add_prefix_space (`bool`, *optional*, defaults to `True`):
-             Whether or not to add an initial space to the input. This allows the leading word to be treated just
-             like any other word. Again, this should be set with `from_slow=True` to make sure it's taken into
-             account.
-     """
- 
-     vocab_files_names = VOCAB_FILES_NAMES
-     model_input_names = ["input_ids", "attention_mask"]
-     _auto_class = "AutoTokenizer"
- 
-     def __init__(
-         self,
-         vocab_file,
-         unk_token="<unk>",
-         bos_token="<s>",
-         eos_token="</s>",
-         pad_token=None,
-         sp_model_kwargs: Optional[Dict[str, Any]] = None,
-         add_bos_token=True,
-         add_eos_token=False,
-         clean_up_tokenization_spaces=False,
-         use_default_system_prompt=False,
-         spaces_between_special_tokens=False,
-         spaces_for_interleaved_special_tokens=False,
-         add_prefix_space=True,
-         **kwargs,
-     ):
-         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
-         bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
-         eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
-         unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
-         pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
- 
-         self.vocab_file = vocab_file
-         self.add_bos_token = add_bos_token
-         self.add_eos_token = add_eos_token
-         self.use_default_system_prompt = use_default_system_prompt
-         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-         self.sp_model.Load(vocab_file)
-         self.add_prefix_space = add_prefix_space
-         self.spaces_for_interleaved_special_tokens = spaces_for_interleaved_special_tokens
- 
-         vocab_size = self.sp_model.get_piece_size()
-         self.decoder = {i: self.sp_model.id_to_piece(i) for i in range(vocab_size)}
- 
-         super().__init__(
-             bos_token=bos_token,
-             eos_token=eos_token,
-             unk_token=unk_token,
-             pad_token=pad_token,
-             add_bos_token=add_bos_token,
-             add_eos_token=add_eos_token,
-             sp_model_kwargs=sp_model_kwargs,
-             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-             use_default_system_prompt=use_default_system_prompt,
-             spaces_between_special_tokens=spaces_between_special_tokens,
-             add_prefix_space=add_prefix_space,
-             **kwargs,
-         )
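Because the class sets `_auto_class = "AutoTokenizer"`, it is meant to be pulled in through remote code. A minimal usage sketch; the checkpoint path is a placeholder for any directory containing `tokenizer.model` plus this file, not a path confirmed by this repo:

```python
# Hypothetical usage: "path/to/internlm3-checkpoint" is a stand-in directory.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "path/to/internlm3-checkpoint",
    trust_remote_code=True,  # required so the custom InternLM3Tokenizer class is imported
)
ids = tok("Hello world").input_ids  # BOS is prepended because add_bos_token=True
print(tok.convert_ids_to_tokens(ids))
```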
- 
-     def __getstate__(self):
-         state = self.__dict__.copy()
-         state["sp_model"] = None
-         state["sp_model_proto"] = self.sp_model.serialized_model_proto()
-         return state
- 
-     def __setstate__(self, d):
-         self.__dict__.update(d)
-         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-         self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
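This pair of methods implements the standard trick for pickling a wrapper around a C++ object: drop the unpicklable `SentencePieceProcessor` from the state and carry its serialized proto bytes instead. A small round-trip sketch, assuming `tok` is an already constructed `InternLM3Tokenizer`:

```python
import pickle

payload = pickle.dumps(tok)       # __getstate__ swaps sp_model for its serialized proto
restored = pickle.loads(payload)  # __setstate__ re-creates the SentencePieceProcessor
assert restored.vocab_size == tok.vocab_size
```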
- 
-     @property
-     def vocab_size(self):
-         """Returns vocab size."""
-         return self.sp_model.get_piece_size()
- 
-     def get_vocab(self):
-         """Returns vocab as a dict."""
-         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
-         vocab.update(self.added_tokens_encoder)
-         return vocab
- 
-     def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
-         """
-         Args:
-             text: TextInput
-         Simply calls `PreTrainedTokenizer`'s method.
-         """
-         return super().tokenize(text, **kwargs)
- 
-     def _tokenize(self, text, **kwargs):
-         """
-         Args:
-             text: TextInput
-         Returns a tokenized string.
-         """
-         return self.sp_model.encode(text, out_type=str)
- 
-     def _convert_token_to_id(self, token):
-         """Converts a token (str) into an id using the vocab."""
-         return self.sp_model.piece_to_id(token)
- 
-     def _convert_id_to_token(self, index):
-         """Converts an index (integer) into a token (str) using the vocab."""
-         return self.decoder.get(index, "")
- 
-     def convert_tokens_to_string(self, tokens):
-         """Converts a sequence of tokens (string) into a single string."""
-         # since we manually add the prefix space, we have to remove it when decoding
-         if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
-             tokens[0] = tokens[0][1:]
- 
-         current_sub_tokens = []
-         out_string = ""
-         prev_is_special = False
-         for i, token in enumerate(tokens):
-             # make sure that special tokens are not decoded using the sentencepiece model
-             if token in self.all_special_tokens:
-                 if not prev_is_special and i != 0 and self.spaces_for_interleaved_special_tokens:
-                     out_string += " "
-                 out_string += self.sp_model.decode(current_sub_tokens) + token
-                 prev_is_special = True
-                 current_sub_tokens = []
-             else:
-                 if (
-                     prev_is_special
-                     and i == 1
-                     and self.add_prefix_space
-                     and not token.startswith(SPIECE_UNDERLINE)
-                     and self.spaces_for_interleaved_special_tokens
-                 ):
-                     out_string += " "
-                 current_sub_tokens.append(token)
-                 prev_is_special = False
-         out_string += self.sp_model.decode(current_sub_tokens)
-         return out_string
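To see what `convert_tokens_to_string` is undoing, here is a toy re-enactment of its default path (`spaces_for_interleaved_special_tokens=False`), with `sp_model.decode` replaced by plain string joining so the sketch runs without a SentencePiece model; that substitution is the sketch's only assumption:

```python
SPIECE_UNDERLINE = "▁"

def toy_decode(tokens, special_tokens=("<s>", "</s>"), add_prefix_space=True):
    # Stand-in for convert_tokens_to_string: buffer sub-tokens, flush them around
    # special tokens, and turn the SentencePiece word-boundary marker into a space.
    if tokens and tokens[0].startswith(SPIECE_UNDERLINE) and add_prefix_space:
        tokens = [tokens[0][1:]] + tokens[1:]
    out, buf = "", []
    for tok in tokens:
        if tok in special_tokens:  # specials bypass the subword joiner
            out += "".join(buf).replace(SPIECE_UNDERLINE, " ") + tok
            buf = []
        else:
            buf.append(tok)
    return out + "".join(buf).replace(SPIECE_UNDERLINE, " ")

print(toy_decode(["▁Hello", "▁wor", "ld", "</s>"]))  # -> "Hello world</s>"
```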
- 
-     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
-         """
-         Save the vocabulary and special tokens file to a directory.
- 
-         Args:
-             save_directory (`str`):
-                 The directory in which to save the vocabulary.
- 
-         Returns:
-             `Tuple(str)`: Paths to the files saved.
-         """
-         if not os.path.isdir(save_directory):
-             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-             return
-         out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])
- 
-         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
-             copyfile(self.vocab_file, out_vocab_file)
-         elif not os.path.isfile(self.vocab_file):
-             with open(out_vocab_file, "wb") as fi:
-                 content_spiece_model = self.sp_model.serialized_model_proto()
-                 fi.write(content_spiece_model)
- 
-         return (out_vocab_file,)
- 
-     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
-         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
- 
-         output = bos_token_id + token_ids_0 + eos_token_id
- 
-         if token_ids_1 is not None:
-             output = output + bos_token_id + token_ids_1 + eos_token_id
- 
-         return output
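The wrapping above is plain list concatenation; a self-contained sketch with made-up special-token ids (`1` for BOS and `2` for EOS are illustrative, not values read from the real vocab):

```python
def build_inputs(token_ids_0, token_ids_1=None,
                 bos_id=1, eos_id=2, add_bos=True, add_eos=False):
    # Mirrors build_inputs_with_special_tokens: wrap each sequence with the
    # optional BOS/EOS ids.
    bos = [bos_id] if add_bos else []
    eos = [eos_id] if add_eos else []
    out = bos + token_ids_0 + eos
    if token_ids_1 is not None:
        out += bos + token_ids_1 + eos
    return out

print(build_inputs([10, 11, 12]))        # [1, 10, 11, 12]
print(build_inputs([10, 11], [20, 21]))  # [1, 10, 11, 1, 20, 21]
```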
- 
-     def get_special_tokens_mask(
-         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
-     ) -> List[int]:
-         """
-         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
-         special tokens using the tokenizer's `prepare_for_model` method.
- 
-         Args:
-             token_ids_0 (`List[int]`):
-                 List of IDs.
-             token_ids_1 (`List[int]`, *optional*):
-                 Optional second list of IDs for sequence pairs.
-             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
-                 Whether or not the token list is already formatted with special tokens for the model.
- 
-         Returns:
-             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
-         """
-         if already_has_special_tokens:
-             return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
- 
-         bos_token_id = [1] if self.add_bos_token else []
-         eos_token_id = [1] if self.add_eos_token else []
- 
-         if token_ids_1 is None:
-             return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
-         return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id
- 
-     def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
-         """
-         Creates a mask from the two sequences passed, to be used in a sequence-pair classification task. A sequence
-         pair mask has the following format:
- 
-         ```
-         0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
-         | first sequence    | second sequence |
-         ```
- 
-         If `token_ids_1` is None, only the first portion of the mask (0s) is returned.
- 
-         Args:
-             token_ids_0 (`List[int]`):
-                 List of ids.
-             token_ids_1 (`List[int]`, *optional*):
-                 Optional second list of IDs for sequence pairs.
- 
-         Returns:
-             `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
-         """
-         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
-         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
- 
-         output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
- 
-         if token_ids_1 is not None:
-             output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
- 
-         return output
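These last two methods produce parallel structures: a 0/1 mask flagging special tokens and a 0/1 segment id per position. A sketch of both on toy lengths, mirroring the defaults `add_bos_token=True`, `add_eos_token=False`:

```python
def special_tokens_mask(n0, n1=None, add_bos=True, add_eos=False):
    # 1 marks a special token, 0 a sequence token, matching
    # get_special_tokens_mask for sequences of lengths n0 and n1.
    bos, eos = ([1] if add_bos else []), ([1] if add_eos else [])
    mask = bos + [0] * n0 + eos
    if n1 is not None:
        mask += bos + [0] * n1 + eos
    return mask

def token_type_ids(n0, n1=None, add_bos=True, add_eos=False):
    # 0s cover the first wrapped sequence, 1s the second, matching
    # create_token_type_ids_from_sequences.
    n_special = int(add_bos) + int(add_eos)
    out = [0] * (n0 + n_special)
    if n1 is not None:
        out += [1] * (n1 + n_special)
    return out

print(special_tokens_mask(3, 2))  # [1, 0, 0, 0, 1, 0, 0]
print(token_type_ids(3, 2))       # [0, 0, 0, 0, 1, 1, 1]
```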