fix bug
Browse files- config.json +5 -6
- configuration.json +3 -1
- modeling_qwen.py +8 -9
- qwen_generation_utils.py +0 -1
- tokenization_qwen.py +123 -138
- tokenizer_config.json +1 -1
config.json
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "10302244_iter8000_final/",
|
| 3 |
"architectures": [
|
| 4 |
"QWenLMHeadModel"
|
| 5 |
],
|
| 6 |
"attn_dropout_prob": 0.0,
|
| 7 |
"audio": {
|
| 8 |
"add_audio_bos_eos_token": true,
|
| 9 |
-
"audio_start_id":
|
| 10 |
"avg_pool": true,
|
| 11 |
"n_ctx": 1500,
|
| 12 |
"n_head": 20,
|
|
@@ -19,7 +18,7 @@
|
|
| 19 |
"AutoConfig": "configuration_qwen.QWenConfig",
|
| 20 |
"AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
|
| 21 |
},
|
| 22 |
-
"bf16":
|
| 23 |
"emb_dropout_prob": 0.0,
|
| 24 |
"fp16": false,
|
| 25 |
"fp32": false,
|
|
@@ -27,8 +26,8 @@
|
|
| 27 |
"initializer_range": 0.02,
|
| 28 |
"intermediate_size": 22016,
|
| 29 |
"kv_channels": 128,
|
| 30 |
-
"layer_norm_epsilon": 1e-
|
| 31 |
-
"max_position_embeddings":
|
| 32 |
"model_type": "qwen",
|
| 33 |
"no_bias": true,
|
| 34 |
"num_attention_heads": 32,
|
|
@@ -47,7 +46,7 @@
|
|
| 47 |
"use_cache_kernel": false,
|
| 48 |
"use_cache_quantization": false,
|
| 49 |
"use_dynamic_ntk": true,
|
| 50 |
-
"use_flash_attn":
|
| 51 |
"use_logn_attn": true,
|
| 52 |
"vocab_size": 155947
|
| 53 |
}
|
|
|
|
| 1 |
{
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"QWenLMHeadModel"
|
| 4 |
],
|
| 5 |
"attn_dropout_prob": 0.0,
|
| 6 |
"audio": {
|
| 7 |
"add_audio_bos_eos_token": true,
|
| 8 |
+
"audio_start_id": 155163,
|
| 9 |
"avg_pool": true,
|
| 10 |
"n_ctx": 1500,
|
| 11 |
"n_head": 20,
|
|
|
|
| 18 |
"AutoConfig": "configuration_qwen.QWenConfig",
|
| 19 |
"AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
|
| 20 |
},
|
| 21 |
+
"bf16": false,
|
| 22 |
"emb_dropout_prob": 0.0,
|
| 23 |
"fp16": false,
|
| 24 |
"fp32": false,
|
|
|
|
| 26 |
"initializer_range": 0.02,
|
| 27 |
"intermediate_size": 22016,
|
| 28 |
"kv_channels": 128,
|
| 29 |
+
"layer_norm_epsilon": 1e-06,
|
| 30 |
+
"max_position_embeddings": 2048,
|
| 31 |
"model_type": "qwen",
|
| 32 |
"no_bias": true,
|
| 33 |
"num_attention_heads": 32,
|
|
|
|
| 46 |
"use_cache_kernel": false,
|
| 47 |
"use_cache_quantization": false,
|
| 48 |
"use_dynamic_ntk": true,
|
| 49 |
+
"use_flash_attn": "auto",
|
| 50 |
"use_logn_attn": true,
|
| 51 |
"vocab_size": 155947
|
| 52 |
}
|
configuration.json
CHANGED
|
@@ -1 +1,3 @@
|
|
| 1 |
-
{"framework":"Pytorch",
|
|
|
|
|
|
|
|
|
| 1 |
+
{"framework":"Pytorch",
|
| 2 |
+
"task":"multimodal-dialogue",
|
| 3 |
+
"allow_remote": true}
|
modeling_qwen.py
CHANGED
|
@@ -1015,20 +1015,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|
| 1015 |
self.lm_head.half()
|
| 1016 |
self.post_init()
|
| 1017 |
|
| 1018 |
-
|
| 1019 |
@classmethod
|
| 1020 |
def from_pretrained(
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
):
|
| 1028 |
if os.path.isdir(pretrained_model_name_or_path):
|
| 1029 |
# Local Directory of Models
|
| 1030 |
mel_filters_path = os.path.join(pretrained_model_name_or_path, 'mel_filters.npz')
|
| 1031 |
-
print(mel_filters_path)
|
| 1032 |
tgt_cache_path = os.path.join(os.path.dirname(__file__), 'mel_filters.npz')
|
| 1033 |
shutil.copy(mel_filters_path, tgt_cache_path)
|
| 1034 |
else:
|
|
@@ -1036,7 +1034,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|
| 1036 |
from huggingface_hub import hf_hub_download
|
| 1037 |
hf_hub_download(repo_id=pretrained_model_name_or_path, filename="mel_filters.npz",
|
| 1038 |
token=kwargs.get('token', None), local_dir=os.path.dirname(__file__))
|
| 1039 |
-
return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir,
|
|
|
|
| 1040 |
|
| 1041 |
def get_output_embeddings(self):
|
| 1042 |
return self.lm_head
|
|
|
|
| 1015 |
self.lm_head.half()
|
| 1016 |
self.post_init()
|
| 1017 |
|
|
|
|
| 1018 |
@classmethod
|
| 1019 |
def from_pretrained(
|
| 1020 |
+
cls,
|
| 1021 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
| 1022 |
+
*model_args,
|
| 1023 |
+
config=None,
|
| 1024 |
+
cache_dir=None,
|
| 1025 |
+
**kwargs,
|
| 1026 |
):
|
| 1027 |
if os.path.isdir(pretrained_model_name_or_path):
|
| 1028 |
# Local Directory of Models
|
| 1029 |
mel_filters_path = os.path.join(pretrained_model_name_or_path, 'mel_filters.npz')
|
|
|
|
| 1030 |
tgt_cache_path = os.path.join(os.path.dirname(__file__), 'mel_filters.npz')
|
| 1031 |
shutil.copy(mel_filters_path, tgt_cache_path)
|
| 1032 |
else:
|
|
|
|
| 1034 |
from huggingface_hub import hf_hub_download
|
| 1035 |
hf_hub_download(repo_id=pretrained_model_name_or_path, filename="mel_filters.npz",
|
| 1036 |
token=kwargs.get('token', None), local_dir=os.path.dirname(__file__))
|
| 1037 |
+
return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir,
|
| 1038 |
+
**kwargs)
|
| 1039 |
|
| 1040 |
def get_output_embeddings(self):
|
| 1041 |
return self.lm_head
|
qwen_generation_utils.py
CHANGED
|
@@ -186,7 +186,6 @@ def make_context(
|
|
| 186 |
+ nl_tokens
|
| 187 |
)
|
| 188 |
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
|
| 189 |
-
print(raw_text)
|
| 190 |
audio_info = tokenizer.process_audio(raw_text)
|
| 191 |
|
| 192 |
elif chat_format == "raw":
|
|
|
|
| 186 |
+ nl_tokens
|
| 187 |
)
|
| 188 |
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
|
|
|
|
| 189 |
audio_info = tokenizer.process_audio(raw_text)
|
| 190 |
|
| 191 |
elif chat_format == "raw":
|
tokenization_qwen.py
CHANGED
|
@@ -17,13 +17,11 @@ from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Opt
|
|
| 17 |
|
| 18 |
import tiktoken
|
| 19 |
import numpy as np
|
| 20 |
-
|
| 21 |
-
from PIL import ImageFont
|
| 22 |
-
from PIL import ImageDraw
|
| 23 |
from transformers import PreTrainedTokenizer, AddedToken
|
| 24 |
from transformers.utils import try_to_load_from_cache
|
| 25 |
-
from transformers.tokenization_utils_base import BatchEncoding,PaddingStrategy,TruncationStrategy,\
|
| 26 |
-
TextInput,TextInputPair,PreTokenizedInput,PreTokenizedInputPair,TensorType, EncodedInput, EncodedInputPair
|
| 27 |
|
| 28 |
import matplotlib.colors as mcolors
|
| 29 |
from matplotlib.font_manager import FontProperties
|
|
@@ -31,7 +29,6 @@ from .audio import *
|
|
| 31 |
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
| 34 |
-
|
| 35 |
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}
|
| 36 |
|
| 37 |
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
|
@@ -43,11 +40,11 @@ IMEND = "<|im_end|>"
|
|
| 43 |
# as different as possible to minimize the impact
|
| 44 |
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
|
| 45 |
SPECIAL_TOKENS = (
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
) + EXTRAS
|
| 50 |
-
|
| 51 |
LANGUAGES = {
|
| 52 |
"en": "english",
|
| 53 |
"zh": "chinese",
|
|
@@ -68,23 +65,25 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
|
|
| 68 |
for token, rank in (line.split() for line in contents.splitlines() if line)
|
| 69 |
}
|
| 70 |
|
|
|
|
| 71 |
def _list_find(
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
):
|
| 76 |
for i in range(start, len(input_list)):
|
| 77 |
if input_list[i] in candidates:
|
| 78 |
return i
|
| 79 |
return -1
|
| 80 |
|
|
|
|
| 81 |
def _replace_closed_tag(
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
):
|
| 89 |
if isinstance(start_tags, (str, int)):
|
| 90 |
start_tags = (start_tags,)
|
|
@@ -99,107 +98,93 @@ def _replace_closed_tag(
|
|
| 99 |
start = _list_find(input_tokens, start_tags, end)
|
| 100 |
if start == -1:
|
| 101 |
break
|
| 102 |
-
output_tokens.extend(exclusive_replace_func(input_tokens[end
|
| 103 |
tag_idx = start_tags.index(input_tokens[start])
|
| 104 |
end = _list_find(input_tokens, (end_tags[tag_idx],), start)
|
| 105 |
if end == -1:
|
| 106 |
-
raise ValueError("Unclosed
|
| 107 |
-
output_tokens.extend(inclusive_replace_func(input_tokens[start
|
| 108 |
end += 1
|
| 109 |
audio_idx += 1
|
| 110 |
-
output_tokens.extend(exclusive_replace_func(input_tokens[end
|
| 111 |
return output_tokens
|
| 112 |
|
|
|
|
| 113 |
class QWenTokenizer(PreTrainedTokenizer):
|
| 114 |
"""QWen tokenizer."""
|
| 115 |
|
| 116 |
vocab_files_names = VOCAB_FILES_NAMES
|
| 117 |
|
| 118 |
def __init__(
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
):
|
| 126 |
super().__init__(**kwargs)
|
| 127 |
self.audio_start_tag = audio_start_tag
|
| 128 |
self.audio_end_tag = audio_end_tag
|
| 129 |
self.audio_pad_tag = "[[[AUDIO:modality]]]"
|
| 130 |
-
self.IMAGE_ST = ("<ref>", "</ref>", "<box>", "</box>", "<quad>", "</quad>")
|
| 131 |
|
| 132 |
self.AUDIO_ST = (
|
| 133 |
'[[[AUDIO:modality]]]',
|
| 134 |
-
|
| 135 |
-
"<|
|
| 136 |
-
#
|
|
|
|
| 137 |
"<|translate|>",
|
| 138 |
"<|transcribe|>",
|
| 139 |
"<|caption|>",
|
| 140 |
"<|keyword|>",
|
| 141 |
-
#
|
| 142 |
-
"<|unknown|>", #
|
| 143 |
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
| 144 |
-
"<|
|
| 145 |
-
#
|
| 146 |
"<|notimestamps|>",
|
| 147 |
"<|sil|>",
|
| 148 |
"<|timestamps|>",
|
| 149 |
-
*[f"<|{i * 0.01:.2f}|>" for i in range(3001)],
|
| 150 |
-
#
|
| 151 |
-
"<|caption_audiocaps|>", #
|
| 152 |
-
"<|caption_clotho|>", #
|
| 153 |
-
"<|audioset_ontology|>", #
|
| 154 |
-
"<|caption_plain|>", #
|
| 155 |
-
"<|itn|>", #
|
| 156 |
-
"<|wo_itn|>", #
|
| 157 |
-
# 特殊任务——实体识别
|
| 158 |
"<|startofentityvalue|>",
|
| 159 |
"<|endofentityvalue|>",
|
| 160 |
"<|startofentitytype|>",
|
| 161 |
"<|endofentitytype|>",
|
| 162 |
-
"<|named_entity_recognition|>",
|
| 163 |
-
|
| 164 |
-
"<|grounding|>",
|
| 165 |
"<|startofword|>",
|
| 166 |
"<|endofword|>",
|
| 167 |
-
"<|delim|>", #
|
| 168 |
-
#
|
| 169 |
-
"<|
|
| 170 |
-
#
|
| 171 |
-
"<|
|
| 172 |
-
#
|
| 173 |
-
"<|
|
| 174 |
-
"<|
|
| 175 |
-
|
| 176 |
-
"<|
|
| 177 |
-
|
| 178 |
-
"<|
|
| 179 |
-
"<|
|
| 180 |
-
#
|
| 181 |
-
"<|
|
| 182 |
-
#
|
| 183 |
-
"<|
|
| 184 |
-
#
|
| 185 |
-
"<|
|
| 186 |
-
"<|
|
| 187 |
-
"<|
|
| 188 |
-
#
|
| 189 |
-
"<|
|
| 190 |
-
# 子任务--event
|
| 191 |
-
"<|event|>",
|
| 192 |
-
# 子任务--vocal_classification
|
| 193 |
-
"<|vocal_classification|>",
|
| 194 |
-
# 特殊任务--SLU
|
| 195 |
-
"<|speech_understanding|>",
|
| 196 |
-
"<|scenario|>",
|
| 197 |
-
"<|action|>",
|
| 198 |
-
"<|entities|>",
|
| 199 |
-
# 子任务--语音编辑
|
| 200 |
-
"<|speech_edit|>",
|
| 201 |
-
# 子任务--命令
|
| 202 |
-
"<|speech_command|>",
|
| 203 |
audio_start_tag,
|
| 204 |
audio_end_tag
|
| 205 |
)
|
|
@@ -210,9 +195,8 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 210 |
self.special_tokens = {
|
| 211 |
token: index
|
| 212 |
for index, token in enumerate(
|
| 213 |
-
# SPECIAL_TOKENS + self.IMAGE_ST + self.AUDIO_ST, start=len(self.mergeable_ranks)
|
| 214 |
SPECIAL_TOKENS + self.AUDIO_ST, start=len(self.mergeable_ranks)
|
| 215 |
-
|
| 216 |
)
|
| 217 |
}
|
| 218 |
self.audio_start_id = self.special_tokens[self.audio_start_tag]
|
|
@@ -229,7 +213,7 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 229 |
special_tokens=self.special_tokens,
|
| 230 |
)
|
| 231 |
assert (
|
| 232 |
-
|
| 233 |
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
|
| 234 |
|
| 235 |
self.decoder = {
|
|
@@ -260,7 +244,6 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 260 |
)
|
| 261 |
self.tokenizer = enc
|
| 262 |
|
| 263 |
-
|
| 264 |
def __len__(self) -> int:
|
| 265 |
return self.tokenizer.n_vocab
|
| 266 |
|
|
@@ -268,7 +251,7 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 268 |
return self.mergeable_ranks
|
| 269 |
|
| 270 |
def convert_tokens_to_ids(
|
| 271 |
-
|
| 272 |
) -> List[int]:
|
| 273 |
ids = []
|
| 274 |
if isinstance(tokens, (str, bytes)):
|
|
@@ -288,7 +271,7 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 288 |
raise ValueError('Adding regular tokens is not supported')
|
| 289 |
for token in new_tokens:
|
| 290 |
surface_form = token.content if isinstance(token, AddedToken) else token
|
| 291 |
-
if surface_form not in SPECIAL_TOKENS
|
| 292 |
raise ValueError('Adding unknown special tokens is not supported')
|
| 293 |
return 0
|
| 294 |
|
|
@@ -307,12 +290,12 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 307 |
return (file_path,)
|
| 308 |
|
| 309 |
def tokenize(
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
) -> List[Union[bytes, str]]:
|
| 317 |
"""
|
| 318 |
Converts a string in a sequence of tokens.
|
|
@@ -338,44 +321,46 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 338 |
|
| 339 |
# this implementation takes a detour: text -> token id -> token surface forms
|
| 340 |
for t in self.tokenizer.encode(
|
| 341 |
-
|
| 342 |
):
|
| 343 |
tokens.append(self.decoder[t])
|
| 344 |
|
| 345 |
def _encode_audiourl(audio_tokens, audio_info, audio_idx):
|
| 346 |
assert audio_tokens[0] == self.audio_start_tag and audio_tokens[-1] == self.audio_end_tag
|
| 347 |
audio_token_span = audio_info['audio_span_tokens'][audio_idx]
|
| 348 |
-
out_audio_tokens = [self.audio_start_tag] + [self.audio_pad_tag]*(audio_token_span-2) + [
|
|
|
|
| 349 |
return out_audio_tokens
|
| 350 |
|
| 351 |
-
return _replace_closed_tag(tokens, self.audio_start_tag, self.audio_end_tag, _encode_audiourl,
|
|
|
|
| 352 |
|
| 353 |
def _batch_encode_plus(
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
) -> BatchEncoding:
|
| 380 |
|
| 381 |
def get_input_ids(text):
|
|
@@ -409,7 +394,7 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 409 |
for pair_id in range(len(batch_text_or_text_pairs)):
|
| 410 |
kwargs['audio_info'] = audio_info[pair_id]
|
| 411 |
ids_or_pair_ids = batch_text_or_text_pairs[pair_id]
|
| 412 |
-
|
| 413 |
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
| 414 |
ids, pair_ids = ids_or_pair_ids, None
|
| 415 |
elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
|
|
@@ -488,23 +473,23 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 488 |
raise NotImplementedError
|
| 489 |
|
| 490 |
def _decode(
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
) -> str:
|
| 497 |
if isinstance(token_ids, int):
|
| 498 |
token_ids = [token_ids]
|
| 499 |
audio_info = kwargs.pop("audio_info", None)
|
| 500 |
|
| 501 |
-
|
| 502 |
def _decode_audiourl(audio_token_ids, audio_info, audio_idx):
|
| 503 |
assert audio_token_ids[0] == self.audio_start_id and audio_token_ids[-1] == self.audio_end_id
|
| 504 |
audio_url = audio_info["audio_urls"][audio_idx]
|
| 505 |
return [self.audio_start_id] + self.tokenizer.encode(audio_url) + [self.audio_end_id]
|
| 506 |
|
| 507 |
-
token_ids = _replace_closed_tag(token_ids, self.audio_start_id, self.audio_end_id, _decode_audiourl,
|
|
|
|
| 508 |
|
| 509 |
if skip_special_tokens:
|
| 510 |
token_ids = [i for i in token_ids if i < self.eod_id]
|
|
@@ -513,7 +498,7 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 513 |
def to_list_format(self, text: str):
|
| 514 |
text = unicodedata.normalize("NFC", text)
|
| 515 |
token_ids = self.tokenizer.encode(
|
| 516 |
-
text, allowed_special=set(self.
|
| 517 |
|
| 518 |
def _encode_audio_info(tokens):
|
| 519 |
if len(tokens) == 0:
|
|
@@ -561,10 +546,10 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 561 |
|
| 562 |
def process_audio(self, text):
|
| 563 |
audio_urls = self.extract_audio_urls(text)
|
| 564 |
-
if len(audio_urls)> 0:
|
| 565 |
audios, audio_lens, audio_span_tokens = [], [], []
|
| 566 |
for audio_path in audio_urls:
|
| 567 |
-
if audio_path.startswith("http://") or audio_path.startswith("https://"):
|
| 568 |
data = bytes(requests.get(audio_path, stream=True).content)
|
| 569 |
audio = load_bytesio_audio(data)
|
| 570 |
else:
|
|
@@ -578,7 +563,7 @@ class QWenTokenizer(PreTrainedTokenizer):
|
|
| 578 |
audio_len = [audio_len_after_cnn, audio_token_num]
|
| 579 |
audios.append(mel)
|
| 580 |
audio_lens.append(audio_len)
|
| 581 |
-
audio_span_tokens.append(audio_token_num+2)
|
| 582 |
input_audio_lengths = torch.IntTensor(audio_lens)
|
| 583 |
input_audios = torch.stack(audios, dim=0)
|
| 584 |
return {"input_audios": input_audios,
|
|
|
|
| 17 |
|
| 18 |
import tiktoken
|
| 19 |
import numpy as np
|
| 20 |
+
|
|
|
|
|
|
|
| 21 |
from transformers import PreTrainedTokenizer, AddedToken
|
| 22 |
from transformers.utils import try_to_load_from_cache
|
| 23 |
+
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TruncationStrategy, \
|
| 24 |
+
TextInput, TextInputPair, PreTokenizedInput, PreTokenizedInputPair, TensorType, EncodedInput, EncodedInputPair
|
| 25 |
|
| 26 |
import matplotlib.colors as mcolors
|
| 27 |
from matplotlib.font_manager import FontProperties
|
|
|
|
| 29 |
|
| 30 |
logger = logging.getLogger(__name__)
|
| 31 |
|
|
|
|
| 32 |
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}
|
| 33 |
|
| 34 |
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
|
|
|
| 40 |
# as different as possible to minimize the impact
|
| 41 |
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
|
| 42 |
SPECIAL_TOKENS = (
|
| 43 |
+
ENDOFTEXT,
|
| 44 |
+
IMSTART,
|
| 45 |
+
IMEND,
|
| 46 |
+
) + EXTRAS
|
| 47 |
+
|
| 48 |
LANGUAGES = {
|
| 49 |
"en": "english",
|
| 50 |
"zh": "chinese",
|
|
|
|
| 65 |
for token, rank in (line.split() for line in contents.splitlines() if line)
|
| 66 |
}
|
| 67 |
|
| 68 |
+
|
| 69 |
def _list_find(
|
| 70 |
+
input_list: List[Any],
|
| 71 |
+
candidates: Tuple[Any],
|
| 72 |
+
start: int = 0,
|
| 73 |
):
|
| 74 |
for i in range(start, len(input_list)):
|
| 75 |
if input_list[i] in candidates:
|
| 76 |
return i
|
| 77 |
return -1
|
| 78 |
|
| 79 |
+
|
| 80 |
def _replace_closed_tag(
|
| 81 |
+
input_tokens: List[Any],
|
| 82 |
+
start_tags: Union[Any, Tuple[Any]],
|
| 83 |
+
end_tags: Union[Any, Tuple[Any]],
|
| 84 |
+
inclusive_replace_func: Callable,
|
| 85 |
+
exclusive_replace_func: Callable = lambda x: x,
|
| 86 |
+
audio_info: Dict = None
|
| 87 |
):
|
| 88 |
if isinstance(start_tags, (str, int)):
|
| 89 |
start_tags = (start_tags,)
|
|
|
|
| 98 |
start = _list_find(input_tokens, start_tags, end)
|
| 99 |
if start == -1:
|
| 100 |
break
|
| 101 |
+
output_tokens.extend(exclusive_replace_func(input_tokens[end: start]))
|
| 102 |
tag_idx = start_tags.index(input_tokens[start])
|
| 103 |
end = _list_find(input_tokens, (end_tags[tag_idx],), start)
|
| 104 |
if end == -1:
|
| 105 |
+
raise ValueError("Unclosed audio token")
|
| 106 |
+
output_tokens.extend(inclusive_replace_func(input_tokens[start: end + 1], audio_info, audio_idx))
|
| 107 |
end += 1
|
| 108 |
audio_idx += 1
|
| 109 |
+
output_tokens.extend(exclusive_replace_func(input_tokens[end:]))
|
| 110 |
return output_tokens
|
| 111 |
|
| 112 |
+
|
| 113 |
class QWenTokenizer(PreTrainedTokenizer):
|
| 114 |
"""QWen tokenizer."""
|
| 115 |
|
| 116 |
vocab_files_names = VOCAB_FILES_NAMES
|
| 117 |
|
| 118 |
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
vocab_file,
|
| 121 |
+
errors="replace",
|
| 122 |
+
audio_start_tag='<audio>',
|
| 123 |
+
audio_end_tag='</audio>',
|
| 124 |
+
**kwargs,
|
| 125 |
):
|
| 126 |
super().__init__(**kwargs)
|
| 127 |
self.audio_start_tag = audio_start_tag
|
| 128 |
self.audio_end_tag = audio_end_tag
|
| 129 |
self.audio_pad_tag = "[[[AUDIO:modality]]]"
|
|
|
|
| 130 |
|
| 131 |
self.AUDIO_ST = (
|
| 132 |
'[[[AUDIO:modality]]]',
|
| 133 |
+
# Transcription Tag
|
| 134 |
+
"<|startoftranscript|>", # Transcription
|
| 135 |
+
"<|startofanalysis|>", # Analysis
|
| 136 |
+
# Task Tag
|
| 137 |
"<|translate|>",
|
| 138 |
"<|transcribe|>",
|
| 139 |
"<|caption|>",
|
| 140 |
"<|keyword|>",
|
| 141 |
+
# Language Tag
|
| 142 |
+
"<|unknown|>", # unknown language
|
| 143 |
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
| 144 |
+
"<|zh_tr|>", # tranditional Chinese
|
| 145 |
+
# Timestamps Tag
|
| 146 |
"<|notimestamps|>",
|
| 147 |
"<|sil|>",
|
| 148 |
"<|timestamps|>",
|
| 149 |
+
*[f"<|{i * 0.01:.2f}|>" for i in range(3001)], # timestamps 0.00-30.00
|
| 150 |
+
# Output Instruction
|
| 151 |
+
"<|caption_audiocaps|>", # Audiocaps caption style
|
| 152 |
+
"<|caption_clotho|>", # Clotho caption style
|
| 153 |
+
"<|audioset_ontology|>", # Audioset ontology style
|
| 154 |
+
"<|caption_plain|>", # plain caption
|
| 155 |
+
"<|itn|>", # inversed text normalized
|
| 156 |
+
"<|wo_itn|>", # without inversed text normalized
|
|
|
|
| 157 |
"<|startofentityvalue|>",
|
| 158 |
"<|endofentityvalue|>",
|
| 159 |
"<|startofentitytype|>",
|
| 160 |
"<|endofentitytype|>",
|
| 161 |
+
"<|named_entity_recognition|>", # named entity recognition task
|
| 162 |
+
"<|audio_grounding|>",
|
|
|
|
| 163 |
"<|startofword|>",
|
| 164 |
"<|endofword|>",
|
| 165 |
+
"<|delim|>", # delimiter of timestamps pair in audio grounding
|
| 166 |
+
"<|emotion_recognition|>", # emotion recognition
|
| 167 |
+
"<|music_description|>", # music description
|
| 168 |
+
"<|note_analysis|>", # note analysis
|
| 169 |
+
"<|pitch|>", # note analysis: pitch
|
| 170 |
+
*[f"<|midi_pitch_{i}|>" for i in range(128)], # midi pitch 0-127
|
| 171 |
+
"<|velocity|>", # note analysis: velocity
|
| 172 |
+
*[f"<|midi_velocity_{i}|>" for i in range(128)], # midi velocity 0-127
|
| 173 |
+
"<|sonic|>", # note analysis: sonic
|
| 174 |
+
"<|instrument|>", # note analysis: instrument
|
| 175 |
+
"<|speaker_meta|>", # meta information of speaker
|
| 176 |
+
"<|song_meta|>", # meta information of song
|
| 177 |
+
"<|question|>", # AQA: question
|
| 178 |
+
"<|answer|>", # AQA: answer
|
| 179 |
+
"<|choice|>", # AQA: answer choice
|
| 180 |
+
"<|scene|>", # scene recognition
|
| 181 |
+
"<|event|>", # sound event
|
| 182 |
+
"<|vocal_classification|>", # vocal classification
|
| 183 |
+
"<|speech_understanding|>", # speech language understanding
|
| 184 |
+
"<|scenario|>", # speech language understanding: scenario
|
| 185 |
+
"<|action|>", # speech language understanding: action
|
| 186 |
+
"<|entities|>", # speech language understanding: entities
|
| 187 |
+
"<|speech_edit|>", # speech edit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
audio_start_tag,
|
| 189 |
audio_end_tag
|
| 190 |
)
|
|
|
|
| 195 |
self.special_tokens = {
|
| 196 |
token: index
|
| 197 |
for index, token in enumerate(
|
|
|
|
| 198 |
SPECIAL_TOKENS + self.AUDIO_ST, start=len(self.mergeable_ranks)
|
| 199 |
+
|
| 200 |
)
|
| 201 |
}
|
| 202 |
self.audio_start_id = self.special_tokens[self.audio_start_tag]
|
|
|
|
| 213 |
special_tokens=self.special_tokens,
|
| 214 |
)
|
| 215 |
assert (
|
| 216 |
+
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
|
| 217 |
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
|
| 218 |
|
| 219 |
self.decoder = {
|
|
|
|
| 244 |
)
|
| 245 |
self.tokenizer = enc
|
| 246 |
|
|
|
|
| 247 |
def __len__(self) -> int:
|
| 248 |
return self.tokenizer.n_vocab
|
| 249 |
|
|
|
|
| 251 |
return self.mergeable_ranks
|
| 252 |
|
| 253 |
def convert_tokens_to_ids(
|
| 254 |
+
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
|
| 255 |
) -> List[int]:
|
| 256 |
ids = []
|
| 257 |
if isinstance(tokens, (str, bytes)):
|
|
|
|
| 271 |
raise ValueError('Adding regular tokens is not supported')
|
| 272 |
for token in new_tokens:
|
| 273 |
surface_form = token.content if isinstance(token, AddedToken) else token
|
| 274 |
+
if surface_form not in SPECIAL_TOKENS + self.AUDIO_ST:
|
| 275 |
raise ValueError('Adding unknown special tokens is not supported')
|
| 276 |
return 0
|
| 277 |
|
|
|
|
| 290 |
return (file_path,)
|
| 291 |
|
| 292 |
def tokenize(
|
| 293 |
+
self,
|
| 294 |
+
text: str,
|
| 295 |
+
allowed_special: Union[Set, str] = "all",
|
| 296 |
+
disallowed_special: Union[Collection, str] = (),
|
| 297 |
+
audio_info: Dict = None,
|
| 298 |
+
**kwargs,
|
| 299 |
) -> List[Union[bytes, str]]:
|
| 300 |
"""
|
| 301 |
Converts a string in a sequence of tokens.
|
|
|
|
| 321 |
|
| 322 |
# this implementation takes a detour: text -> token id -> token surface forms
|
| 323 |
for t in self.tokenizer.encode(
|
| 324 |
+
text, allowed_special=allowed_special, disallowed_special=disallowed_special
|
| 325 |
):
|
| 326 |
tokens.append(self.decoder[t])
|
| 327 |
|
| 328 |
def _encode_audiourl(audio_tokens, audio_info, audio_idx):
|
| 329 |
assert audio_tokens[0] == self.audio_start_tag and audio_tokens[-1] == self.audio_end_tag
|
| 330 |
audio_token_span = audio_info['audio_span_tokens'][audio_idx]
|
| 331 |
+
out_audio_tokens = [self.audio_start_tag] + [self.audio_pad_tag] * (audio_token_span - 2) + [
|
| 332 |
+
self.audio_end_tag]
|
| 333 |
return out_audio_tokens
|
| 334 |
|
| 335 |
+
return _replace_closed_tag(tokens, self.audio_start_tag, self.audio_end_tag, _encode_audiourl,
|
| 336 |
+
audio_info=audio_info)
|
| 337 |
|
| 338 |
def _batch_encode_plus(
|
| 339 |
+
self,
|
| 340 |
+
batch_text_or_text_pairs: Union[
|
| 341 |
+
List[TextInput],
|
| 342 |
+
List[TextInputPair],
|
| 343 |
+
List[PreTokenizedInput],
|
| 344 |
+
List[PreTokenizedInputPair],
|
| 345 |
+
List[EncodedInput],
|
| 346 |
+
List[EncodedInputPair],
|
| 347 |
+
],
|
| 348 |
+
add_special_tokens: bool = True,
|
| 349 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 350 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 351 |
+
max_length: Optional[int] = None,
|
| 352 |
+
stride: int = 0,
|
| 353 |
+
is_split_into_words: bool = False,
|
| 354 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 355 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 356 |
+
return_token_type_ids: Optional[bool] = None,
|
| 357 |
+
return_attention_mask: Optional[bool] = None,
|
| 358 |
+
return_overflowing_tokens: bool = False,
|
| 359 |
+
return_special_tokens_mask: bool = False,
|
| 360 |
+
return_offsets_mapping: bool = False,
|
| 361 |
+
return_length: bool = False,
|
| 362 |
+
verbose: bool = True,
|
| 363 |
+
**kwargs,
|
| 364 |
) -> BatchEncoding:
|
| 365 |
|
| 366 |
def get_input_ids(text):
|
|
|
|
| 394 |
for pair_id in range(len(batch_text_or_text_pairs)):
|
| 395 |
kwargs['audio_info'] = audio_info[pair_id]
|
| 396 |
ids_or_pair_ids = batch_text_or_text_pairs[pair_id]
|
| 397 |
+
# for ids_or_pair_ids in batch_text_or_text_pairs:
|
| 398 |
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
| 399 |
ids, pair_ids = ids_or_pair_ids, None
|
| 400 |
elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
|
|
|
|
| 473 |
raise NotImplementedError
|
| 474 |
|
| 475 |
def _decode(
|
| 476 |
+
self,
|
| 477 |
+
token_ids: Union[int, List[int]],
|
| 478 |
+
skip_special_tokens: bool = False,
|
| 479 |
+
errors: str = None,
|
| 480 |
+
**kwargs,
|
| 481 |
) -> str:
|
| 482 |
if isinstance(token_ids, int):
|
| 483 |
token_ids = [token_ids]
|
| 484 |
audio_info = kwargs.pop("audio_info", None)
|
| 485 |
|
|
|
|
| 486 |
def _decode_audiourl(audio_token_ids, audio_info, audio_idx):
|
| 487 |
assert audio_token_ids[0] == self.audio_start_id and audio_token_ids[-1] == self.audio_end_id
|
| 488 |
audio_url = audio_info["audio_urls"][audio_idx]
|
| 489 |
return [self.audio_start_id] + self.tokenizer.encode(audio_url) + [self.audio_end_id]
|
| 490 |
|
| 491 |
+
token_ids = _replace_closed_tag(token_ids, self.audio_start_id, self.audio_end_id, _decode_audiourl,
|
| 492 |
+
audio_info=audio_info)
|
| 493 |
|
| 494 |
if skip_special_tokens:
|
| 495 |
token_ids = [i for i in token_ids if i < self.eod_id]
|
|
|
|
| 498 |
def to_list_format(self, text: str):
|
| 499 |
text = unicodedata.normalize("NFC", text)
|
| 500 |
token_ids = self.tokenizer.encode(
|
| 501 |
+
text, allowed_special=set(self.AUDIO_ST + (ENDOFTEXT,)))
|
| 502 |
|
| 503 |
def _encode_audio_info(tokens):
|
| 504 |
if len(tokens) == 0:
|
|
|
|
| 546 |
|
| 547 |
def process_audio(self, text):
|
| 548 |
audio_urls = self.extract_audio_urls(text)
|
| 549 |
+
if len(audio_urls) > 0:
|
| 550 |
audios, audio_lens, audio_span_tokens = [], [], []
|
| 551 |
for audio_path in audio_urls:
|
| 552 |
+
if audio_path.startswith("http://") or audio_path.startswith("https://"): # http
|
| 553 |
data = bytes(requests.get(audio_path, stream=True).content)
|
| 554 |
audio = load_bytesio_audio(data)
|
| 555 |
else:
|
|
|
|
| 563 |
audio_len = [audio_len_after_cnn, audio_token_num]
|
| 564 |
audios.append(mel)
|
| 565 |
audio_lens.append(audio_len)
|
| 566 |
+
audio_span_tokens.append(audio_token_num + 2) # add audio bos eos
|
| 567 |
input_audio_lengths = torch.IntTensor(audio_lens)
|
| 568 |
input_audios = torch.stack(audios, dim=0)
|
| 569 |
return {"input_audios": input_audios,
|
tokenizer_config.json
CHANGED
|
@@ -6,6 +6,6 @@
|
|
| 6 |
]
|
| 7 |
},
|
| 8 |
"clean_up_tokenization_spaces": true,
|
| 9 |
-
"model_max_length":
|
| 10 |
"tokenizer_class": "QWenTokenizer"
|
| 11 |
}
|
|
|
|
| 6 |
]
|
| 7 |
},
|
| 8 |
"clean_up_tokenization_spaces": true,
|
| 9 |
+
"model_max_length": 2048,
|
| 10 |
"tokenizer_class": "QWenTokenizer"
|
| 11 |
}
|