tim1900 commited on
Commit
18d98e3
·
verified ·
1 Parent(s): 75f7e3e

Upload 7 files

Browse files
README.md CHANGED
@@ -1,3 +1,55 @@
1
  ---
2
- license: mit
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language:
3
+ - en
4
+ - zh
5
+ pipeline_tag: token-classification
6
  ---
7
+ # bert-chunker-chinese
8
+
9
+ ## Introduction
10
+
11
+ bert-chunker-chinese is a Chinese text chunker based on BERT with a classifier head that predicts the start token of chunks (for use in RAG, etc.); using a sliding window, it cuts documents of any size into chunks. It was finetuned on top of [bge-small-zh-v1.5](https://huggingface.co/BAAI/bge-small-zh-v1.5).
12
+
13
+ This repo includes the model checkpoint, the BertChunker class definition file, and all other files needed.
14
+
15
+ ## Quickstart
16
+ Download this repository, change into its directory, and run the following:
17
+
18
+ ```python
19
+ # -*- coding: utf-8 -*-
20
+ import safetensors
21
+ from transformers import AutoConfig,AutoTokenizer
22
+ from modeling_bertchunke_zh import BertChunker
23
+
24
+ # load config and tokenizer
25
+ config = AutoConfig.from_pretrained(
26
+ "tim1900/bert-chunker-chinese",
27
+ trust_remote_code=True,
28
+ )
29
+ tokenizer = AutoTokenizer.from_pretrained(
30
+ "tim1900/bert-chunker-chinese",
31
+ padding_side="right",
32
+ model_max_length=config.max_position_embeddings,
33
+ trust_remote_code=True,
34
+ )
35
+
36
+ # initialize model
37
+ model = BertChunker(config)
38
+ device='cpu' # or 'cuda'
39
+ model.to(device)
40
+
41
+ # load tim1900/bert-chunker-chinese/model.safetensors
42
+ state_dict = safetensors.torch.load_file("./model.safetensors")
43
+ model.load_state_dict(state_dict)
44
+
45
+ # text to be chunked
46
+ text='''起点中文网(www.qidian.com)创立于2002年5月,是国内知名的原创文学网站,隶属于阅文集团旗下。起点中文网以推动中国原创文学事业为宗旨,长期致力于原创文学作者的挖掘与培养,并取得了巨大成果:2003年10月,起点中文网开启“在线收费阅读”服务,成为真正意义上的网络文学赢利模式的先锋之一,就此奠定了原创文学的行业基础。此后,起点又推出了作家福利、文学交互、内容发掘推广、版权管理等机制和体系,为原创文学的发展注入了巨大活力,有力推动了中国文学原创事业的发展。在清晨的微光中,一只孤独的猫头鹰在古老的橡树上低声吟唱,它的歌声如同夜色的回声,穿越了时间的迷雾。树叶在微风中轻轻摇曳,仿佛在诉说着古老的故事,每一个音符都带着森林的秘密。一位年轻的程序员正专注地敲打着键盘,代码的海洋在他眼前展开。他的手指在键盘上飞舞,如同钢琴家在演奏一曲复杂的交响乐。屏幕上的光标闪烁,仿佛在等待着下一个指令,引领他进入未知的数字世界。'''
47
+
48
+ # chunk the text. The lower threshold is, the more chunks will be generated. Can be negative or positive.
49
+ chunks=model.chunk_text(text, tokenizer, threshold=0.5)
50
+
51
+ # print chunks
52
+ for i, c in enumerate(chunks):
53
+ print(f'-----chunk: {i}------------')
54
+ print(c)
55
+ ```
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/data/bge-small-zh-v1.5",
3
+ "architectures": [
4
+ "BertChunker"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 512,
12
+ "id2label": {
13
+ "0": "LABEL_0"
14
+ },
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 2048,
17
+ "label2id": {
18
+ "LABEL_0": 0
19
+ },
20
+ "layer_norm_eps": 1e-12,
21
+ "max_position_embeddings": 512,
22
+ "model_type": "bert",
23
+ "num_attention_heads": 8,
24
+ "num_hidden_layers": 4,
25
+ "pad_token_id": 0,
26
+ "position_embedding_type": "absolute",
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.46.3",
29
+ "type_vocab_size": 2,
30
+ "use_cache": true,
31
+ "vocab_size": 21128
32
+ }
modeling_bertchunke_zh.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.modeling_utils import PreTrainedModel
2
+ from torch import nn
3
+ from transformers.models.bert.configuration_bert import BertConfig
4
+ from transformers.models.bert.modeling_bert import BertModel
5
+ import torch
6
+ import torch.nn.functional as F
7
class BertChunker(PreTrainedModel):
    """BERT-based text chunker.

    A ``BertModel`` backbone plus a linear head that classifies every token
    as chunk-start (class 1) or not (class 0).  The ``chunk_text*`` methods
    slide a context window over an arbitrarily long tokenized document and
    cut the text at predicted chunk-start tokens.
    """

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = BertModel(config)
        # Binary classification head over token embeddings: [not-start, start].
        self.chunklayer = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the backbone and the chunk-start head.

        Returns the backbone output augmented with per-token ``"logits"``
        (batch x seq x 2) and, when ``labels`` is provided, ``"loss"``.
        """
        model_output = self.model(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        token_embeddings = model_output[0]
        logits = self.chunklayer(token_embeddings)
        model_output["logits"] = logits
        logits = logits.contiguous()
        # FIX: use identity test instead of `labels != None` (elementwise on tensors).
        if labels is not None:
            labels = labels.contiguous()
            # Flatten the tokens; CrossEntropyLoss ignores -100 labels by default.
            loss_fct = nn.CrossEntropyLoss()
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism: move labels to the logits device.
            # (The original `labels.to(labels.device)` was a no-op.)
            labels = labels.to(logits.device)
            model_output["loss"] = loss_fct(logits, labels)

        return model_output

    @torch.no_grad()
    def chunk_text(self, text: str, tokenizer, threshold=0.5) -> list[str]:
        """Chunk ``text`` with a sliding context window.

        The lower ``threshold`` is, the more chunks are generated; it may be
        negative or positive.  Returns the list of chunk substrings.
        """
        MAX_TOKENS = self.model.config.max_position_embeddings
        tokens = tokenizer(text, return_tensors="pt", truncation=False)
        input_ids = tokens["input_ids"].to(self.device)
        # Keep the special tokens so each window can be framed as [CLS] ... [SEP].
        CLS = input_ids[:, 0].unsqueeze(0)
        SEP = input_ids[:, -1].unsqueeze(0)
        input_ids = input_ids[:, 1:-1]
        self.eval()
        split_str_poses = []

        windows_start = 0
        windows_end = 0

        while windows_end <= input_ids.shape[1]:
            windows_end = windows_start + MAX_TOKENS - 2

            ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
            ids = ids.to(self.device)

            output = self(
                input_ids=ids,
                attention_mask=torch.ones(1, ids.shape[1], device=self.device),
            )
            # Drop CLS/SEP positions; class-1 probability = chunk-start score.
            logits = output["logits"][:, 1:-1, :]
            chunk_probabilities = F.softmax(logits, dim=-1)[:, :, 1]
            chunk_decision = chunk_probabilities > threshold
            greater_rows_indices = torch.where(chunk_decision)[1].tolist()

            # If splits were found (and not only at window position 0), record
            # them and restart the window at the last split; otherwise jump a
            # full window forward.
            if greater_rows_indices and not (
                greater_rows_indices[0] == 0 and len(greater_rows_indices) == 1
            ):
                split_str_poses += [
                    tokens.token_to_chars(sp + windows_start + 1).start
                    for sp in greater_rows_indices
                ]
                windows_start = greater_rows_indices[-1] + windows_start
            else:
                windows_start = windows_end

        return [
            text[i:j]
            for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])
        ]

    @torch.no_grad()
    def chunk_text_smooth(self, text: str, tokenizer, threshold=0) -> list[str]:
        """Chunk ``text`` by averaging chunk-start probabilities over
        overlapping windows (half-window stride) before thresholding."""
        MAX_TOKENS = self.model.config.max_position_embeddings
        tokens = tokenizer(text, return_tensors="pt", truncation=False)
        input_ids = tokens["input_ids"].to(self.device)
        CLS = input_ids[:, 0].unsqueeze(0)
        SEP = input_ids[:, -1].unsqueeze(0)
        input_ids = input_ids[:, 1:-1]
        self.eval()

        # One probability list per token; every overlapping window that covers
        # the token contributes one entry.
        prob_pair_list = [[] for _ in range(input_ids.shape[1])]

        windows_start = 0
        while windows_start <= input_ids.shape[1]:
            windows_end = windows_start + MAX_TOKENS - 2

            ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
            ids = ids.to(self.device)

            output = self(
                input_ids=ids,
                attention_mask=torch.ones(1, ids.shape[1], device=self.device),
            )
            logits = output["logits"][:, 1:-1, :]
            chunk_probabilities = F.softmax(logits, dim=-1).tolist()

            for i in range(windows_start, windows_start + len(chunk_probabilities[0])):
                prob_pair_list[i].append(chunk_probabilities[0][i - windows_start][1])

            # Half-window stride so every token is scored by at least one window.
            windows_start = windows_start + MAX_TOKENS // 2 - 1

        split_str_poses = [
            tokens.token_to_chars(i + 1).start
            for i, probs in enumerate(prob_pair_list)
            if sum(probs) / len(probs) > threshold
        ]

        return [
            text[i:j]
            for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])
        ]

    @torch.no_grad()
    def chunk_text_fast(
        self, text: str, tokenizer, batchsize=20, threshold=0
    ) -> list[str]:
        """Chunk ``text`` faster with fixed non-overlapping windows;
        ``batchsize`` is the number of windows run per forward pass."""
        self.eval()

        split_str_poses = []
        MAX_TOKENS = self.model.config.max_position_embeddings
        USEFUL_TOKENS = MAX_TOKENS - 2  # room left after [CLS] and [SEP]
        tokens = tokenizer(text, return_tensors="pt", truncation=False)
        input_ids = tokens["input_ids"]

        CLS = tokenizer.cls_token_id
        SEP = tokenizer.sep_token_id

        input_ids = input_ids[:, 1:-1].squeeze().contiguous()  # delete cls and sep

        token_num = input_ids.shape[0]
        seq_num = token_num // USEFUL_TOKENS
        left_token_num = token_num % USEFUL_TOKENS

        if seq_num > 0:
            reshaped_input_ids = input_ids[: seq_num * USEFUL_TOKENS].view(
                seq_num, USEFUL_TOKENS
            )

            # 1-based position of every window token in the full tokenization
            # (bias accounts for the leading [CLS] of the original encoding).
            i = torch.arange(seq_num).unsqueeze(1)
            j = torch.arange(USEFUL_TOKENS).repeat(seq_num, 1)
            bias = 1
            position_id = (i * USEFUL_TOKENS + j + bias).to(self.device)

            reshaped_input_ids = torch.cat(
                (
                    torch.full((reshaped_input_ids.shape[0], 1), CLS),
                    reshaped_input_ids,
                    torch.full((reshaped_input_ids.shape[0], 1), SEP),
                ),
                1,
            )

            batch_num = seq_num // batchsize
            left_seq_num = seq_num % batchsize
            for b in range(batch_num):
                # FIX: slice by batch offset.  The original used the loop index
                # directly (`[i : i + batchsize]`), re-scoring overlapping rows
                # and silently skipping most windows.
                rows = slice(b * batchsize, (b + 1) * batchsize)
                batch_input = reshaped_input_ids[rows, :].to(self.device)
                attention_mask = torch.ones(
                    batch_input.shape[0], batch_input.shape[1], device=self.device
                )
                output = self(input_ids=batch_input, attention_mask=attention_mask)
                logits = output["logits"][:, 1:-1, :]  # delete cls and sep
                is_left_greater = (logits[:, :, 0] + threshold) < logits[:, :, 1]
                pos = is_left_greater * position_id[rows, :]
                pos = pos[pos > 0].tolist()
                split_str_poses += [tokens.token_to_chars(p).start for p in pos]
            if left_seq_num > 0:
                batch_input = reshaped_input_ids[-left_seq_num:, :].to(self.device)
                attention_mask = torch.ones(
                    batch_input.shape[0], batch_input.shape[1], device=self.device
                )
                output = self(input_ids=batch_input, attention_mask=attention_mask)
                logits = output["logits"][:, 1:-1, :]  # delete cls and sep
                is_left_greater = (logits[:, :, 0] + threshold) < logits[:, :, 1]
                pos = is_left_greater * position_id[-left_seq_num:, :]
                pos = pos[pos > 0].tolist()
                split_str_poses += [tokens.token_to_chars(p).start for p in pos]

        if left_token_num > 0:
            # Trailing tokens that do not fill a whole window.
            left_input_ids = torch.cat(
                [torch.tensor([CLS]), input_ids[-left_token_num:], torch.tensor([SEP])]
            )
            left_input_ids = left_input_ids.unsqueeze(0).to(self.device)
            attention_mask = torch.ones(
                left_input_ids.shape[0], left_input_ids.shape[1], device=self.device
            )
            output = self(input_ids=left_input_ids, attention_mask=attention_mask)
            logits = output["logits"][:, 1:-1, :]  # delete cls and sep
            is_left_greater = (logits[:, :, 0] + threshold) < logits[:, :, 1]
            bias = token_num - (left_input_ids.shape[1] - 2) + 1
            pos = (torch.where(is_left_greater)[1] + bias).tolist()
            split_str_poses += [tokens.token_to_chars(p).start for p in pos]

        return [
            text[i:j]
            for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])
        ]
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": false,
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "never_split": null,
51
+ "pad_token": "[PAD]",
52
+ "padding_side": "right",
53
+ "sep_token": "[SEP]",
54
+ "strip_accents": null,
55
+ "tokenize_chinese_chars": true,
56
+ "tokenizer_class": "BertTokenizer",
57
+ "unk_token": "[UNK]"
58
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff