alever_sn committed · Commit bb727b8 · 1 parent: 85b3f16
README.md CHANGED
@@ -1,3 +1,349 @@
- ---
- license: mit
- ---
# Improving Low-Resource Sequence Labeling with Knowledge Fusion and Contextual Label Explanations

![workflow](docs/assets/workflow.png)

## 🌍 Overview

This repository provides the official implementation of our paper:

> **Improving Low-Resource Sequence Labeling with Knowledge Fusion and Contextual Label Explanations**
> [arXiv:2501.19093](https://arxiv.org/abs/2501.19093)

Low-resource sequence labeling often suffers from data sparsity and limited contextual generalization.
We propose **KnowFREE (Knowledge-Fused Representation Enhancement Framework)**, which integrates **external linguistic knowledge** and **contextual label explanations** into the model's representation space to improve low-resource performance.

**Key Highlights:**

We combine an **LLM-based knowledge enhancement workflow** with a **span-based KnowFREE model** to address these challenges.

**Pipeline 1: Label Extension Annotation**

* Objective: leverage LLMs to generate extension entity labels, word segmentation tags, and POS tags for the original samples.
* Effect:
  * Enhances the model's understanding of fine-grained contextual semantics.
  * Improves the ability to distinguish entity boundaries in character-dense languages.

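For illustration only (a hypothetical record, not taken from the released data), Pipeline 1 turns a plain sample into one that carries extension spans alongside the target entities; extension labels carry `is_target: false` in the label file, so they steer training but are filtered out of the final predictions:

```python
# Hypothetical Pipeline-1 output: the original target entities plus
# LLM-generated extension spans (here, POS tags).
sample = {
    "text": ["邓", "超", "在", "北", "京"],
    "entities": [
        {"start": 0, "end": 2, "entity": "PER.NAM", "text": ["邓", "超"]},  # target label
        {"start": 0, "end": 2, "entity": "NOUN", "text": ["邓", "超"]},     # extension POS span
        {"start": 3, "end": 5, "entity": "LOC.NAM", "text": ["北", "京"]},  # target label
        {"start": 2, "end": 3, "entity": "ADPOSITION", "text": ["在"]},     # extension POS span
    ],
}

# Keep only the target spans, as inference-time decoding does.
extension_types = {"NOUN", "ADPOSITION"}
target_only = [e for e in sample["entities"] if e["entity"] not in extension_types]
```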
**Pipeline 2: Enriched Explanation Synthesis**

* Objective: use LLMs to generate detailed, context-aware explanations for target entities, thereby synthesizing new, high-quality training samples.
* Effect:
  * Effectively mitigates the semantic distribution bias between synthetic samples and the target domain.
  * Significantly expands the number of samples and improves model performance in extremely low-resource settings.

---

## 🔗 Quick Links

- [Model Checkpoints](#♠️-model-checkpoints)
- [Data Augmentation Workflow](#📊-data-augmentation-workflow)
- [Train KnowFREE](#🔥-run-knowfree-models)
- [Citation](#📚-citation)

## ♠️ Model Checkpoints

Because of the large number of experiments, the architectural differences between the initial and reconstructed models, and the limited practical value of low-resource checkpoints sampled from the full dataset, we release only a few representative checkpoints (e.g., Weibo) on Hugging Face for reference:

| Model | F1 |
| :--- | :---: |
| [aleversn/KnowFREE-Weibo-BERT-base (Many shots 1000 with ChatGLM3)](https://huggingface.co/aleversn/GCSE-BERT-base) | 76.78 |
| [aleversn/KnowFREE-Youku-BERT-base (Many shots 1000 with ChatGLM3)](https://huggingface.co/aleversn/GCSE-BERT-large) | 84.50 |

---

## 🧩 KnowFREE Framework

![KnowFREE](docs/assets/knowfree.png)

**Architecture**: a Biaffine-based span model that supports **nested entity** annotation.

**Core Innovations:**

* Introduces a **Local Multi-head Attention Layer** to efficiently fuse the multi-type extension label features generated in Pipeline 1.
* **No external knowledge needed at inference:** the model learns to fuse knowledge during training; the logits of extension labels are masked at inference time.

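A minimal sketch of that inference-time masking (assuming, as in the released `fusion_ner.py`, that scores have shape `[batch, L, L, num_labels]` with extension labels occupying the trailing channels from `ext_labels_start_idx` onward):

```python
import numpy as np

def mask_extension_labels(scores: np.ndarray, ext_start: int = 8) -> np.ndarray:
    """Zero the extension-label channels so only target labels can be decoded."""
    masked = scores.copy()
    masked[..., ext_start:] = 0.0  # extension labels can never exceed the span threshold
    return masked

# Toy example: 1 sample, a 2x2 span grid, 10 label channels (8 target + 2 extension).
scores = np.full((1, 2, 2, 10), 0.9)
masked = mask_extension_labels(scores)
```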
---

## ⚙️ Installation Guide

### Core Dependencies

Create an environment and install the dependencies:

```bash
conda create -n knowfree python=3.8
conda activate knowfree
```

```bash
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install transformers==4.18.0 fastNLP==1.0.1 PrettyTable
pip install torch-scatter==2.0.8 -f https://data.pyg.org/whl/torch-1.8.0+cu111.html
```

## 📊 Data Augmentation Workflow

See the detailed data synthesis pipeline in [Syn_Pipelines](docs/Syn_Pipelines.md).

In KnowFREE, we employ **contextual paraphrasing and label explanation synthesis** to augment low-resource datasets.
For each entity label, LLMs generate descriptive explanations that are integrated into the learning process to mitigate label semantic sparsity.

---

## 🔥 Run KnowFREE Models

### Training with `KnowFREE`

#### Dataset Format

Specify the dataset path with the `data_present_path` argument (default: `./datasets/present.json`). The file should be a JSON object with the following format:

```json
{
    "weibo": {
        "train": "./datasets/weibo/train.jsonl",
        "dev": "./datasets/weibo/dev.jsonl",
        "test": "./datasets/weibo/test.jsonl",
        "labels": "./datasets/weibo/labels.txt"
    }
}
```

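The registry above can be read with a small helper (a sketch; `load_present` and `iter_jsonl` are illustrative names, not part of the released code):

```python
import json

def load_present(path="./datasets/present.json"):
    """Parse the dataset registry: {name: {train/dev/test/labels paths}}."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def iter_jsonl(path):
    """Yield one sample dict per non-empty line of a .jsonl file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)
```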
**Train Samples of Different Languages:**

- Chinese

```jsonl
{"text": ["科", "技", "全", "方", "位", "资", "讯", "智", "能", ",", "快", "捷", "的", "汽", "车", "生", "活", "需", "要", "有", "三", "屏", "一", "云", "爱", "你"], "label": ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"], "entities": []}
{"text": ["对", ",", "输", "给", "一", "个", "女", "人", ",", "的", "成", "绩", "。", "失", "望"], "label": ["O", "O", "O", "O", "O", "O", "B-PER.NOM", "E-PER.NOM", "O", "O", "O", "O", "O", "O", "O"], "entities": [{"start": 6, "entity": "PER.NOM", "end": 8, "text": ["女", "人"]}]}
{"text": ["今", "天", "下", "午", "起", "来", "看", "到", "外", "面", "的", "太", "阳", "。", "。", "。", "。", "我", "第", "一", "反", "应", "竟", "然", "是", "强", "烈", "的", "想", "回", "家", "泪", "想", "我", "们", "一", "起", "在", "嘉", "鱼", "个", "时", "候", "了", "。", "。", "。", "。", "有", "好", "多", "好", "多", "的", "话", "想", "对", "你", "说", "李", "巾", "凡", "想", "要", "瘦", "瘦", "瘦", "成", "李", "帆", "我", "是", "想", "切", "开", "云", "朵", "的", "心"], "label": ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-LOC.NAM", "E-LOC.NAM", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-PER.NAM", "I-PER.NAM", "E-PER.NAM", "O", "O", "O", "O", "O", "O", "B-PER.NAM", "E-PER.NAM", "O", "O", "O", "O", "O", "O", "O", "O", "O"], "entities": [{"start": 38, "entity": "LOC.NAM", "end": 40, "text": ["嘉", "鱼"]}, {"start": 59, "entity": "PER.NAM", "end": 62, "text": ["李", "巾", "凡"]}, {"start": 68, "entity": "PER.NAM", "end": 70, "text": ["李", "帆"]}]}
```

- English

```jsonl
{"text": ["im", "thinking", "of", "a", "comedy", "where", "a", "group", "of", "husbands", "receive", "one", "chance", "from", "their", "wives", "to", "engage", "with", "other", "women"], "entities": [{"start": 4, "end": 5, "entity": "GENRE", "text": ["comedy"]}, {"start": 6, "end": 21, "entity": "PLOT", "text": ["a", "group", "of", "husbands", "receive", "one", "chance", "from", "their", "wives", "to", "engage", "with", "other", "women"]}]}
{"text": ["another", "sequel", "of", "an", "action", "movie", "about", "drag", "street", "car", "races", "alcohol", "and", "gun", "violence"], "entities": [{"start": 1, "end": 2, "entity": "RELATIONSHIP", "text": ["sequel"]}, {"start": 4, "end": 5, "entity": "GENRE", "text": ["action"]}, {"start": 7, "end": 15, "entity": "PLOT", "text": ["drag", "street", "car", "races", "alcohol", "and", "gun", "violence"]}]}
{"text": ["what", "is", "the", "name", "of", "the", "movie", "in", "which", "a", "group", "of", "criminals", "begin", "to", "suspect", "that", "one", "of", "them", "is", "a", "police", "informant", "after", "a", "simple", "jewelery", "heist", "goes", "terribly", "wrong"], "entities": [{"start": 9, "end": 32, "entity": "PLOT", "text": ["a", "group", "of", "criminals", "begin", "to", "suspect", "that", "one", "of", "them", "is", "a", "police", "informant", "after", "a", "simple", "jewelery", "heist", "goes", "terribly", "wrong"]}]}
{"text": ["a", "movie", "with", "vin", "diesel", "in", "world", "war", "2", "in", "a", "foreign", "country", "shooting", "people"], "entities": [{"start": 3, "end": 5, "entity": "ACTOR", "text": ["vin", "diesel"]}, {"start": 6, "end": 9, "entity": "GENRE", "text": ["world", "war", "2"]}, {"start": 11, "end": 15, "entity": "PLOT", "text": ["foreign", "country", "shooting", "people"]}]}
{"text": ["what", "is", "the", "1991", "disney", "animated", "movie", "that", "featured", "angela", "lansbury", "as", "the", "voice", "of", "a", "teapot"], "entities": [{"start": 3, "end": 4, "entity": "YEAR", "text": ["1991"]}, {"start": 5, "end": 6, "entity": "GENRE", "text": ["animated"]}, {"start": 9, "end": 11, "entity": "ACTOR", "text": ["angela", "lansbury"]}, {"start": 16, "end": 17, "entity": "CHARACTER_NAME", "text": ["teapot"]}]}
```

- Japanese

```jsonl
{"text": ["I", "n", "f", "o", "r", "m", "i", "x", "の", "動", "き", "を", "み", "て", "、", "オ", "ラ", "ク", "ル", "と", "I", "B", "M", "も", "追", "随", "し", "た", "。"], "entities": [{"start": 0, "end": 8, "entity": "法人名", "text": ["I", "n", "f", "o", "r", "m", "i", "x"]}, {"start": 15, "end": 19, "entity": "法人名", "text": ["オ", "ラ", "ク", "ル"]}, {"start": 20, "end": 23, "entity": "法人名", "text": ["I", "B", "M"]}]}
{"text": ["現", "在", "は", "ア", "ニ", "メ", "ー", "シ", "ョ", "ン", "業", "界", "か", "ら", "退", "い", "て", "お", "り", "、", "水", "彩", "画", "家", "と", "し", "て", "も", "活", "動", "し", "て", "い", "る", "。"], "entities": []}
{"text": ["大", "野", "東", "イ", "ン", "タ", "ー", "チ", "ェ", "ン", "ジ", "は", "、", "大", "分", "県", "豊", "後", "大", "野", "市", "大", "野", "町", "後", "田", "に", "あ", "る", "中", "九", "州", "横", "断", "道", "路", "の", "イ", "ン", "タ", "ー", "チ", "ェ", "ン", "ジ", "で", "あ", "る", "。"], "entities": [{"start": 0, "end": 11, "entity": "施設名", "text": ["大", "野", "東", "イ", "ン", "タ", "ー", "チ", "ェ", "ン", "ジ"]}, {"start": 13, "end": 26, "entity": "地名", "text": ["大", "分", "県", "豊", "後", "大", "野", "市", "大", "野", "町", "後", "田"]}, {"start": 29, "end": 36, "entity": "施設名", "text": ["中", "九", "州", "横", "断", "道", "路"]}]}
{"text": ["2", "0", "1", "4", "年", "1", "月", "1", "5", "日", "、", "マ", "バ", "タ", "は", "ミ", "ャ", "ン", "マ", "ー", "の", "上", "座", "部", "仏", "教", "を", "擁", "護", "す", "る", "使", "命", "を", "持", "っ", "て", "、", "マ", "ン", "ダ", "レ", "ー", "の", "仏", "教", "僧", "の", "大", "規", "模", "な", "会", "議", "で", "正", "式", "に", "設", "立", "さ", "れ", "た", "。"], "entities": [{"start": 11, "end": 14, "entity": "法人名", "text": ["マ", "バ", "タ"]}, {"start": 15, "end": 20, "entity": "地名", "text": ["ミ", "ャ", "ン", "マ", "ー"]}, {"start": 38, "end": 43, "entity": "地名", "text": ["マ", "ン", "ダ", "レ", "ー"]}]}
{"text": ["永", "泰", "荘", "駅", "は", "、", "中", "華", "人", "民", "共", "和", "国", "北", "京", "市", "海", "淀", "区", "に", "位", "置", "す", "る", "北", "京", "地", "下", "鉄", "8", "号", "線", "の", "駅", "で", "あ", "る", "。"], "entities": [{"start": 0, "end": 4, "entity": "施設名", "text": ["永", "泰", "荘", "駅"]}, {"start": 6, "end": 19, "entity": "地名", "text": ["中", "華", "人", "民", "共", "和", "国", "北", "京", "市", "海", "淀", "区"]}]}
```

- Korean

```jsonl
{"text": ["그", "모습", "을", "보", "ㄴ", "민이", "는", "할아버지", "가", "마치", "전쟁터", "에서", "이기", "고", "돌아오", "ㄴ", "장군", "처럼", "의젓", "하", "아", "보이", "ㄴ다고", "생각", "하", "았", "습니다", "."], "entities": [{"start": 5, "end": 6, "entity": "PS", "text": ["민이"]}]}
{"text": ["내달", "18", "일", "부터", "내년", "2", "월", "20", "일", "까지", "는", "서울역", "에서", "무주리조트", "부근", "까지", "스키관광", "열차", "를", "운행", "하", "ㄴ다", "."], "entities": [{"start": 0, "end": 10, "entity": "DT", "text": ["내달", "18", "일", "부터", "내년", "2", "월", "20", "일", "까지"]}, {"start": 11, "end": 12, "entity": "LC", "text": ["서울역"]}, {"start": 13, "end": 14, "entity": "OG", "text": ["무주리조트"]}]}
{"text": ["호소력", "있", "고", "선동", "적", "이", "ㄴ", "주제", "를", "잡아내", "는", "데", "능하", "ㄴ", "즈윅", "이", "지만", "이", "영화", "에서", "는", "무엇", "이", "호소력", "이", "있", "을지", "결정", "하", "지", "못하", "고", "망설이", "ㄴ다", "."], "entities": [{"start": 14, "end": 15, "entity": "PS", "text": ["즈윅"]}]}
{"text": ["그래서", "세호", "는", "밤", "이", "면", "친구", "네", "집", "을", "돌아다니", "며", "아버지", "몰래", "연습", "을", "하", "았", "습니다", "."], "entities": [{"start": 1, "end": 2, "entity": "PS", "text": ["세호"]}, {"start": 3, "end": 4, "entity": "TI", "text": ["밤"]}]}
{"text": ["황씨", "는", "자신", "이", "어리", "어서", "듣", "은", "이", "이야기", "가", "어린이", "들", "에게", "소박", "하", "ㄴ", "효자", "의", "마음", "을", "전하", "아", "주", "ㄹ", "수", "있", "을", "것", "같", "아", "5", "분", "짜리", "구연동화", "로", "각색", "하", "았", "다고", "말", "하", "ㄴ다", "."], "entities": [{"start": 0, "end": 1, "entity": "PS", "text": ["황씨"]}, {"start": 31, "end": 33, "entity": "TI", "text": ["5", "분"]}]}
{"text": ["아버지", "가", "돌아가", "시", "ㄴ", "뒤", "어머니", "의", "편애", "를", "배경", "으로", "승주", "는", "집안", "에서", "만", "은", "대단", "하", "ㄴ", "권세", "를", "누리", "었", "다", "."], "entities": [{"start": 12, "end": 13, "entity": "PS", "text": ["승주"]}]}
```

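For the Chinese samples, the `entities` field is consistent with the BIOES-style `label` sequence (spans use end-exclusive indices). A sketch of that conversion (`bioes_to_entities` is an illustrative helper, not part of the released code):

```python
def bioes_to_entities(labels, text):
    """Convert B-/I-/E-/S-/O tags into {start, end, entity, text} spans (end-exclusive)."""
    entities, start = [], None
    for i, tag in enumerate(labels):
        if tag == "O":
            start = None
            continue
        prefix, etype = tag.split("-", 1)
        if prefix in ("B", "S"):       # span begins here
            start = i
        if prefix in ("S", "E") and start is not None:  # span ends here
            entities.append({"start": start, "end": i + 1, "entity": etype,
                             "text": text[start:i + 1]})
            start = None
    return entities
```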
**Labels**

- `.txt`

```
O
GPE.NAM
GPE.NOM
LOC.NAM
LOC.NOM
ORG.NAM
ORG.NOM
PER.NAM
PER.NOM
```

- `.json` / `.jsonl`

```json
{
    "O": {
        "idx": 0,
        "count": -1,
        "is_target": true
    },
    "GPE.NAM": {
        "idx": 1,
        "count": -1,
        "is_target": true
    },
    "GPE.NOM": {
        "idx": 2,
        "count": -1,
        "is_target": true
    },
    "LOC.NAM": {
        "idx": 3,
        "count": -1,
        "is_target": true
    },
    "LOC.NOM": {
        "idx": 4,
        "count": -1,
        "is_target": true
    },
    "ORG.NAM": {
        "idx": 5,
        "count": -1,
        "is_target": true
    },
    "ORG.NOM": {
        "idx": 6,
        "count": -1,
        "is_target": true
    },
    "PER.NAM": {
        "idx": 7,
        "count": -1,
        "is_target": true
    },
    "PER.NOM": {
        "idx": 8,
        "count": -1,
        "is_target": true
    },
    "ADJECTIVE": {
        "idx": 9,
        "count": 1008,
        "is_target": false
    },
    "ADPOSITION": {
        "idx": 10,
        "count": 41,
        "is_target": false
    },
    "ADVERB": {
        "idx": 11,
        "count": 1147,
        "is_target": false
    },
    "APP": {
        "idx": 12,
        "count": 3,
        "is_target": false
    },
    "AUXILIARY": {
        "idx": 13,
        "count": 4,
        "is_target": false
    }, ...
}
```

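The plain `.txt` list can be turned into the richer `.json` map with a one-liner (illustrative sketch; the `count: -1` and `is_target: true` defaults mirror the target labels in the example above):

```python
def labels_txt_to_map(lines):
    """Build {label: {idx, count, is_target}} from labels.txt lines, in file order."""
    return {
        label.strip(): {"idx": i, "count": -1, "is_target": True}
        for i, label in enumerate(lines) if label.strip()
    }

label_map = labels_txt_to_map(["O", "GPE.NAM", "GPE.NOM", "LOC.NAM"])
```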
* **Model**: BERT / RoBERTa

```python
from main.trainers.knowfree_trainer import Trainer
from transformers import BertTokenizer, BertConfig

MODEL_PATH = "<MODEL_PATH>"
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
config = BertConfig.from_pretrained(MODEL_PATH)
trainer = Trainer(tokenizer=tokenizer, config=config, from_pretrained=MODEL_PATH,
                  data_name='<DATASET_NAME>',
                  batch_size=4,
                  batch_size_eval=8,
                  task_name='<TASK_NAME>')

for i in trainer(num_epochs=120, other_lr=1e-3, weight_decay=0.01, remove_clashed=True, nested=False, eval_call_step=lambda x: x % 125 == 0):
    a = i
```

**Key Params**

- `other_lr`: learning rate for the non-PLM parameters.
- `remove_clashed`: remove predicted spans that overlap with an already-accepted span, so each clash is resolved to a single span.
- `nested`: whether nested entities are supported; for nested datasets such as `CMeEE`, set it to `True` and disable `remove_clashed`.
- `eval_call_step`: a callable that receives the current step `x` and returns whether to run evaluation at that step (e.g., `lambda x: x % 125 == 0`).

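The `remove_clashed` filtering can be sketched in plain Python, mirroring the greedy, score-ordered loop in `fusion_ner.py` (span tuples are `(start, end, label, score)` with inclusive boundaries):

```python
def intersects(a, b):
    """Whether two closed intervals overlap."""
    return a[0] <= b[1] and b[0] <= a[1]

def nested_in(a, b):
    """Whether one closed interval is contained in the other."""
    return (b[0] <= a[0] and a[1] <= b[1]) or (a[0] <= b[0] and b[1] <= a[1])

def remove_clashed(spans, nested=False):
    """Accept spans greedily by descending score, dropping clashing ones."""
    kept = []
    for span in sorted(spans, key=lambda s: s[3], reverse=True):
        clash = any(intersects(span, k) and (not nested or not nested_in(span, k))
                    for k in kept)
        if not clash:
            kept.append(span)
    return kept
```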
#### Evaluation Only

Comment out the training loop to evaluate directly:

```python
trainer.eval(0, is_eval=True)
```

### Training with `CNN Nested NER`

```python
from main.trainers.cnnner_trainer import Trainer
from transformers import BertTokenizer, BertConfig

MODEL_PATH = "<MODEL_PATH>"
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
config = BertConfig.from_pretrained(MODEL_PATH)
trainer = Trainer(tokenizer=tokenizer, config=config, from_pretrained=MODEL_PATH,
                  data_name='<DATASET_NAME>',
                  batch_size=4,
                  batch_size_eval=8,
                  task_name='<TASK_NAME>')

for i in trainer(num_epochs=120, other_lr=1e-3, weight_decay=0.01, remove_clashed=True, nested=False, eval_call_step=lambda x: x % 125 == 0):
    a = i
```

#### Prediction

```python
from main.predictor.knowfree_predictor import KnowFREEPredictor
from transformers import BertTokenizer, BertConfig

MODEL_PATH = "<MODEL_PATH>"
LABEL_FILE = '<LABEL_PATH>'
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
config = BertConfig.from_pretrained(MODEL_PATH)
pred = KnowFREEPredictor(tokenizer=tokenizer, config=config, from_pretrained=MODEL_PATH, label_file=LABEL_FILE, batch_size=4)

for entities in pred(['叶赟葆:全球时尚财运滚滚而来钱', '我要去我要去花心花心花心耶分手大师贵仔邓超四大名捕围观话筒转发邓超贴吧微博号外话筒望周知。邓超四大名捕']):
    print(entities)
```

**Result**

```python
[
    [{'start': 0, 'end': 3, 'entity': 'PER.NAM', 'text': ['叶', '赟', '葆']}],
    [{'start': 45, 'end': 47, 'entity': 'PER.NAM', 'text': ['邓', '超']},
     {'start': 19, 'end': 21, 'entity': 'PER.NAM', 'text': ['邓', '超']},
     {'start': 31, 'end': 33, 'entity': 'PER.NAM', 'text': ['邓', '超']}]
]
```

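The `start`/`end` offsets are end-exclusive character indices into the input string, so recovering the surface forms is a slice (a sketch; `surface_forms` is an illustrative helper):

```python
def surface_forms(text, entities):
    """Map predictor spans back to substrings of the input sentence."""
    return [(e['entity'], text[e['start']:e['end']]) for e in entities]

forms = surface_forms('叶赟葆:全球时尚财运滚滚而来钱',
                      [{'start': 0, 'end': 3, 'entity': 'PER.NAM', 'text': ['叶', '赟', '葆']}])
```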
## 📚 Citation

```bibtex
@misc{lai2025improvinglowresourcesequencelabeling,
  title={Improving Low-Resource Sequence Labeling with Knowledge Fusion and Contextual Label Explanations},
  author={Peichao Lai and Jiaxin Gan and Feiyang Ye and Yilei Wang and Bin Cui},
  year={2025},
  eprint={2501.19093},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2501.19093},
}
```
config.json ADDED
@@ -0,0 +1,100 @@
{
  "_name_or_path": "/home/lpc/models/chinese-bert-wwm-ext/",
  "architectures": [
    "CNNNerv1"
  ],
  "attention_probs_dropout_prob": 0.1,
  "biaffine_size": 200,
  "classifier_dropout": null,
  "cnn_depth": 3,
  "cnn_dim": 200,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10",
    "11": "LABEL_11",
    "12": "LABEL_12",
    "13": "LABEL_13",
    "14": "LABEL_14",
    "15": "LABEL_15",
    "16": "LABEL_16",
    "17": "LABEL_17",
    "18": "LABEL_18",
    "19": "LABEL_19",
    "20": "LABEL_20",
    "21": "LABEL_21",
    "22": "LABEL_22",
    "23": "LABEL_23",
    "24": "LABEL_24",
    "25": "LABEL_25",
    "26": "LABEL_26",
    "27": "LABEL_27"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "kernel_size": 3,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_10": 10,
    "LABEL_11": 11,
    "LABEL_12": 12,
    "LABEL_13": 13,
    "LABEL_14": 14,
    "LABEL_15": 15,
    "LABEL_16": 16,
    "LABEL_17": 17,
    "LABEL_18": 18,
    "LABEL_19": 19,
    "LABEL_2": 2,
    "LABEL_20": 20,
    "LABEL_21": 21,
    "LABEL_22": 22,
    "LABEL_23": 23,
    "LABEL_24": 24,
    "LABEL_25": 25,
    "LABEL_26": 26,
    "LABEL_27": 27,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6,
    "LABEL_7": 7,
    "LABEL_8": 8,
    "LABEL_9": 9
  },
  "layer_norm_eps": 1e-12,
  "logit_drop": 0,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "n_head": 4,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "size_embed_dim": 25,
  "span_threshold": 0.5,
  "torch_dtype": "float32",
  "transformers_version": "4.18.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
enhance_data_info/entity_train_1000.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
enhance_data_info/entity_train_1000_synthetic.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
enhance_data_info/pos_train_1000.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
enhance_data_info/pos_train_1000_synthetic.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
enhance_data_info/train_1000_fusion.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
enhance_data_info/train_1000_synthetic.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
fusion_ner.py ADDED
@@ -0,0 +1,394 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter_max
import numpy as np
from transformers import BertModel, BertPreTrainedModel, BertConfig
from fastNLP import seq_len_to_mask
from typing import Tuple


class CNNNerv1(BertPreTrainedModel):
    def __init__(self, config: BertConfig):
        # model_name, num_ner_tag, cnn_dim=200, biaffine_size=200,
        # size_embed_dim=0, logit_drop=0, kernel_size=3, n_head=4, cnn_depth=3
        super().__init__(config)
        self.hidden_size = config.hidden_size
        self.size_embed_dim = config.size_embed_dim
        self.cnn_dim = config.cnn_dim
        self.biaffine_size = config.biaffine_size
        self.logit_drop = config.logit_drop
        self.kernel_size = config.kernel_size
        self.n_head = config.n_head
        self.cnn_depth = config.cnn_depth
        self.num_labels = config.num_labels
        self.span_threshold = config.span_threshold
        self.ext_labels_start_idx = 8

        self.bert = BertModel(config, add_pooling_layer=False)

        if self.size_embed_dim != 0:
            n_pos = 30  # span-size position ids cover distances between -n_pos and n_pos
            self.size_embedding = torch.nn.Embedding(
                n_pos, self.size_embed_dim)
            # `512 - 512`: subtracting the two ranges yields a matrix whose
            # entries are the signed distance between two positions, e.g.
            # [[0, 1, 2, ..., 511],
            #  [-1, 0, 1, ..., 510],
            #  [...],
            #  [-511, -510, ..., 0]]
            _span_size_ids = torch.arange(
                512) - torch.arange(512).unsqueeze(-1)
            # clamp the span distance to n_pos / 2 on both sides
            _span_size_ids.masked_fill_(_span_size_ids < -n_pos/2, -n_pos/2)
            _span_size_ids = _span_size_ids.masked_fill(
                _span_size_ids >= n_pos/2, n_pos/2-1) + n_pos/2
            # register as a non-trainable buffer
            self.register_buffer('span_size_ids', _span_size_ids.long())
            hsz = self.biaffine_size*2 + self.size_embed_dim + 2
        else:
            hsz = self.biaffine_size*2+2
        biaffine_input_size = self.hidden_size

        self.head_mlp = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(biaffine_input_size, self.biaffine_size),
            nn.LeakyReLU(),
        )
        self.tail_mlp = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(biaffine_input_size, self.biaffine_size),
            nn.LeakyReLU(),
        )

        self.dropout = nn.Dropout(0.4)
        if self.n_head > 0:
            self.multi_head_biaffine = MultiHeadBiaffine(
                self.biaffine_size, self.cnn_dim, n_head=self.n_head)
        else:
            self.U = nn.Parameter(torch.randn(
                self.cnn_dim, self.biaffine_size, self.biaffine_size))
            torch.nn.init.xavier_normal_(self.U.data)
        self.W = torch.nn.Parameter(torch.empty(self.cnn_dim, hsz))
        torch.nn.init.xavier_normal_(self.W.data)
        if self.cnn_depth > 0:
            self.cnn = MaskCNN(self.cnn_dim, self.cnn_dim,
                               kernel_size=self.kernel_size, depth=self.cnn_depth)
        self.attn = LocalAttentionModel(self.cnn_dim, self.kernel_size)

        self.down_fc = nn.Linear(self.cnn_dim, self.num_labels-1)

    def decode_labels(self, labels: torch.Tensor, indexes: torch.Tensor):
        # labels contain no special tokens here, so no offset needs to be subtracted
        length: np.ndarray = indexes.detach().cpu().numpy()
        length = length.max(-1)
        labels[:, :, :, self.ext_labels_start_idx:] = 0
        labels: np.ndarray = labels.detach().cpu().numpy()
        span_mask = (labels.max(-1) > self.span_threshold)
        labels = labels.argmax(-1)
        indexes = np.where(span_mask)
        entities = [set() for _ in range(labels.shape[0])]
        for batch, x, y in zip(*indexes):
            if x <= y and x >= 0 and y >= 0 and x < length[batch] and y < length[batch]:
                entities[batch].add(
                    (x, y, labels[batch, x, y] + 1))  # +1 because of the O label
        return entities

    def is_span_intersect(self, a: Tuple[int, int], b: Tuple[int, int]):
        """
        Whether two intervals intersect; both ends are inclusive.
        """
        return a[0] <= b[1] and b[0] <= a[1]

    def is_span_nested(self, a: Tuple[int, int], b: Tuple[int, int]):
        """
        Whether one interval is nested in the other; both ends are inclusive.
        """
        return (b[0] <= a[0] and a[1] <= b[1]) or (a[0] <= b[0] and b[1] <= a[1])

    def decode_logits(self, scores: torch.Tensor, indexes: torch.Tensor, remove_clashed: bool = False, nested: bool = True):
        scores = scores.sigmoid()
        # scores also exclude special tokens
        # the decoding in the paper's reference code averages the upper and lower triangles:
        # scores = (scores.transpose(1, 2) + scores)/2
        scores: np.ndarray = scores.detach().cpu().numpy()

        length: np.ndarray = indexes.detach().cpu().numpy()
        length = length.max(-1)

        scores[:, :, :, self.ext_labels_start_idx:] = 0
        span_mask = (scores.max(-1) > self.span_threshold)
        argmax = scores.argmax(-1)
        indexes = np.where(span_mask)
        entities = [[] for _ in range(scores.shape[0])]
        # as with the gold labels, there are no special labels here;
        # append the predicted entities to `entities`
        for batch_idx, x, y in zip(*indexes):
            if x >= 0 and x < length[batch_idx] and y >= 0 and y < length[batch_idx] and x <= y:
                # (start, end, label_idx, score)
                entities[batch_idx].append(
                    (x, y, argmax[batch_idx, x, y] + 1, scores[batch_idx, x, y, argmax[batch_idx, x, y]]))
        # for each batch, sort by label score in descending order
        for batch_idx in range(len(entities)):
            entities[batch_idx].sort(key=lambda x: x[-1], reverse=True)
        if remove_clashed:
            for batch_idx in range(len(entities)):
                new_entities = []
                for entity in entities[batch_idx]:
                    add = True
                    for pre_entity in new_entities:
                        if self.is_span_intersect(entity, pre_entity) and (not nested or not self.is_span_nested(entity, pre_entity)):
                            add = False
                            break
                    if add:
                        new_entities.append(entity)
                entities[batch_idx] = new_entities
        # convert to sets
        for batch_idx in range(len(entities)):
            entities[batch_idx] = set(
                map(lambda x: (x[0], x[1], x[2]), entities[batch_idx]))
        return entities

152
+ def forward(self, input_ids: torch.Tensor, bpe_len: torch.Tensor, indexes: torch.Tensor, labels: torch.Tensor = None, is_synthetic: torch.Tensor = None, **kwargs):
153
+ # input_ids 就是常规的input_ids, [batch_size, seq_length, hidden_dim]
154
+ # bpe_len 是flat tokens和[CLS]和[SEP]的长度, 不包括[PAD] [batch_size]
155
+ # indexes 是每个字的坐标[0,1,...], [batch_size, seq_length, hidden_dim]
156
+ # matrix [batch_size, seq_length, seq_length, num_labels] 的0,1矩阵
157
+ attention_mask = seq_len_to_mask(bpe_len) # bsz x length x length
158
+ outputs = self.bert(
159
+ input_ids, attention_mask=attention_mask, return_dict=True)
160
+ last_hidden_states = outputs['last_hidden_state']
161
+ # 这里的效果其实跟W2NER是一样的,就是pieces2word
162
+ # 所有index为0的标签会被选取包含最大的hidden_dim的token, 放置在第0位, 即[CLS], [SEP]和[PAD]的标签
163
+ # 所有index相同的标签会被选取包含最大的hidden_dim的token, 放置在第index位
164
+ # 其余位置补0
165
+ # WARN: 这里会去除前后两个token,因此labels要提前去除前后两个token
166
+ state = scatter_max(last_hidden_states, index=indexes, dim=1)[
167
+ 0][:, 1:] # bsz x word_len x hidden_size
168
+ # 真实的文本-标签对长度
169
+ lengths, _ = indexes.max(dim=-1)
170
+
171
+ # 1. state先传进head和tail的MLP压一下维度得到头尾特征
172
+         head_state = self.head_mlp(state)
+         tail_state = self.tail_mlp(state)
+
+         # 2. Single-head or multi-head biaffine scoring
+         if hasattr(self, 'U'):
+             scores1 = torch.einsum(
+                 'bxi, oij, byj -> boxy', head_state, self.U, tail_state)  # [batch_size, out_dim, word_len, word_len]
+         else:
+             # [batch_size, out_dim, word_len, word_len]
+             scores1 = self.multi_head_biaffine(head_state, tail_state)
+
+         # 3. Append a bias feature to the head and tail states, concatenate them
+         # pairwise over all (head, tail) positions, and add the relative-distance
+         # (span-size) positional embedding.
+         head_state = torch.cat(
+             [head_state, torch.ones_like(head_state[..., :1])], dim=-1)
+         tail_state = torch.cat(
+             [tail_state, torch.ones_like(tail_state[..., :1])], dim=-1)
+         affined_cat = torch.cat([self.dropout(head_state).unsqueeze(2).expand(-1, -1, tail_state.size(1), -1),
+                                  self.dropout(tail_state).unsqueeze(1).expand(-1, head_state.size(1), -1, -1)], dim=-1)
+
+         if hasattr(self, 'size_embedding'):
+             size_embedded = self.size_embedding(
+                 self.span_size_ids[:state.size(1), :state.size(1)])
+             affined_cat = torch.cat([affined_cat,
+                                      self.dropout(size_embedded).unsqueeze(0).expand(state.size(0), -1, -1, -1)], dim=-1)
+
+         scores2 = torch.einsum('bmnh,kh->bkmn', affined_cat,
+                                self.W)  # bsz x dim x L x L
+         scores = scores2 + scores1  # bsz x dim x L x L
+
+         if hasattr(self, 'cnn'):
+             mask = seq_len_to_mask(lengths)  # bsz x length
+             mask = mask[:, None] * mask.unsqueeze(-1)  # bsz x length x length
+             pad_mask = mask[:, None].eq(0)
+             u_scores = scores.masked_fill(pad_mask, 0)
+             if self.logit_drop != 0:
+                 u_scores = F.dropout(
+                     u_scores, p=self.logit_drop, training=self.training)
+             # bsz, num_label, max_len, max_len = u_scores.size()
+             # u_scores = self.cnn(u_scores, pad_mask)
+             u_scores = self.attn(u_scores.permute(0, 2, 3, 1), pad_mask=pad_mask.permute(0, 2, 3, 1))
+             scores = u_scores.permute(0, 3, 1, 2) + scores
+
+         # Move the label dimension to the last axis for the down-projection layer
+         scores = self.down_fc(scores.permute(0, 2, 3, 1))
+
+         loss = None
+         if labels is not None:
+             assert scores.size(-1) == labels.size(-1)
+             flat_scores = scores.reshape(-1)
+             flat_matrix = labels.reshape(-1)
+             # Down-weight the extension (auxiliary) labels relative to target labels
+             decay_weights = torch.ones(labels.size()).to(flat_matrix.device)
+             decay_weights[:, :, :, self.ext_labels_start_idx:] *= 0.13
+             decayed_weights = decay_weights.reshape(input_ids.size(0), -1)
+             # Down-weight synthetic (LLM-generated) samples
+             synthetic_mask = torch.ones(labels.size()).to(flat_matrix.device)
+             synthetic_mask[:, is_synthetic] *= 0.15
+             synthetic_weights = synthetic_mask.reshape(input_ids.size(0), -1)
+             mask = flat_matrix.ne(-100).float().view(input_ids.size(0), -1)
+             flat_loss = F.binary_cross_entropy_with_logits(
+                 flat_scores, flat_matrix.float(), reduction='none')
+             loss = ((flat_loss.view(input_ids.size(0), -1) * synthetic_weights * decayed_weights * mask).sum(dim=-1)).mean()
+
+         return loss, scores
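The loss above scales element-wise BCE by per-cell decay weights (0.13 for extension labels, 0.15 for synthetic samples) and a valid-position mask, sums per sample, then averages over the batch. A minimal NumPy sketch of that reduction (shapes and weight values are illustrative, not the model's real ones):

```python
import math
import numpy as np

def weighted_bce(logits, targets, weights, mask):
    # Element-wise binary cross-entropy on raw logits
    p = 1.0 / (1.0 + np.exp(-logits))
    elem = -(targets * np.log(p) + (1.0 - targets) * np.log(1.0 - p))
    # Scale by per-cell weights and the valid-position mask,
    # sum per sample, then average over the batch
    return (elem * weights * mask).sum(axis=-1).mean()

logits = np.zeros((1, 2))           # sigmoid(0) = 0.5, so each cell costs ln 2
targets = np.array([[1.0, 0.0]])
weights = np.array([[1.0, 0.13]])   # second cell plays the "extension label" role
mask = np.ones((1, 2))
loss = weighted_bce(logits, targets, weights, mask)  # = (1 + 0.13) * ln 2
```

Because the sum runs before the batch mean, samples with more valid cells contribute proportionally larger losses, which is what the code above does as well.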
+
+ class LocalSpanAttention(nn.Module):
+     def __init__(self, dim):
+         super(LocalSpanAttention, self).__init__()
+         self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=10)
+
+     def forward(self, x, span_mask):
+         """
+         :param x: [bsz, len, len, dim] input feature matrix
+         :param span_mask: [bsz, len, len] mask matrix that bounds the attention receptive field
+         """
+         bsz, length, _, dim = x.shape  # unpack the input shape
+
+         # Reshape the input to [bsz * len, len, dim] so it fits MultiheadAttention
+         x = x.view(bsz * length, length, dim)  # flatten the first two dims for attention
+
+         # Transpose to [len, bsz * len, dim], as MultiheadAttention expects
+         x = x.transpose(0, 1)
+
+         # The mask must be passed into the attention computation
+         attention_output, _ = self.attn(x, x, x, attn_mask=span_mask)
+
+         # Restore the [bsz, len, len, dim] shape
+         attention_output = attention_output.transpose(0, 1).view(bsz, length, length, dim)
+
+         return attention_output
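The reshape trick above folds the batch and head-position axes together so that each row of spans becomes an independent "sequence" for attention. A standalone NumPy round-trip check of the same folding (sizes are illustrative):

```python
import numpy as np

bsz, length, dim = 2, 4, 6
x = np.arange(bsz * length * length * dim, dtype=float).reshape(bsz, length, length, dim)

# Fold [bsz, len, len, dim] -> [bsz*len, len, dim], then move the sequence
# axis to the front, mirroring the view/transpose pair in the forward pass
flat = x.reshape(bsz * length, length, dim).transpose(1, 0, 2)

# Inverse: undo the transpose and reshape to recover the original tensor
back = flat.transpose(1, 0, 2).reshape(bsz, length, length, dim)
```

The round trip is lossless, which is what lets the module restore `[bsz, len, len, dim]` exactly after attention.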
+
+ class LocalAttentionModel(nn.Module):
+     def __init__(self, dim, window_size=3):
+         super(LocalAttentionModel, self).__init__()
+         self.local_attention = LocalSpanAttention(dim)
+         self.norm = nn.LayerNorm(dim)
+         self.window_size = window_size
+
+     def generate_local_mask(self, seq_len, window_size):
+         # Build the local-attention mask: only tokens within the window may interact
+         mask = torch.full((seq_len, seq_len), float('-inf'))  # initialize to all -inf
+         for i in range(seq_len):
+             start = max(0, i - window_size)
+             end = min(seq_len, i + window_size + 1)
+             mask[i, start:end] = 0  # allow interaction within the local window
+         return mask
+
+     def forward(self, x, pad_mask):
+         """
+         :param x: [bsz, len, len, dim] input features
+         """
+         bsz, length, _, dim = x.shape
+
+         # Generate the local mask that bounds each span's attention range
+         local_mask = self.generate_local_mask(length, self.window_size)
+         local_mask = local_mask.to(x.device)  # keep the mask on the same device as the input
+
+         # Apply attention over each sample's local spans
+         x = x.masked_fill(pad_mask, 0)
+         attn_output = self.local_attention(x, local_mask)
+
+         # Normalize with LayerNorm
+         output = self.norm(attn_output)
+
+         return output
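`generate_local_mask` builds a banded additive mask: entries outside a ±`window_size` band are `-inf`, so softmax attention assigns them zero weight. A standalone NumPy sketch of the same banding:

```python
import numpy as np

def local_mask(seq_len, window_size):
    # -inf blocks attention; 0 allows it, matching generate_local_mask above
    mask = np.full((seq_len, seq_len), -np.inf)
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        mask[i, start:end] = 0.0
    return mask

m = local_mask(5, 1)
# Row i has zeros only at columns i-1 .. i+1 (clipped at the edges)
```

With `seq_len=5` and `window_size=1` the band holds 13 finite entries (3 per interior row, 2 per edge row), so each position attends to at most three neighbors.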
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, shape=(1, 7, 1, 1), dim_index=1):
+         super(LayerNorm, self).__init__()
+         self.weight = nn.Parameter(torch.ones(shape))
+         self.bias = nn.Parameter(torch.zeros(shape))
+         self.dim_index = dim_index
+         self.eps = 1e-6
+
+     def forward(self, x):
+         """
+         :param x: bsz x dim x max_len x max_len
+         :return: the input normalized over the ``dim_index`` dimension
+         """
+         u = x.mean(dim=self.dim_index, keepdim=True)
+         s = (x - u).pow(2).mean(dim=self.dim_index, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight * x + self.bias
+         return x
+
+
+ class MaskConv2d(nn.Module):
+     def __init__(self, in_ch, out_ch, kernel_size=3, padding=1, groups=1):
+         super(MaskConv2d, self).__init__()
+         self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding,
+                                 bias=False, groups=groups)
+
+     def forward(self, x, mask):
+         """
+         :param x: bsz x in_ch x max_len x max_len
+         :param mask: boolean pad mask; True positions are zeroed before the convolution
+         :return: bsz x out_ch x max_len x max_len
+         """
+         x = x.masked_fill(mask, 0)
+         _x = self.conv2d(x)
+         return _x
+
+
+ class MaskCNN(nn.Module):
+     def __init__(self, input_channels, output_channels, kernel_size=3, depth=3):
+         super(MaskCNN, self).__init__()
+
+         layers = []
+         for i in range(depth):
+             layers.extend([
+                 MaskConv2d(input_channels, input_channels,
+                            kernel_size=kernel_size, padding=kernel_size // 2),
+                 LayerNorm((1, input_channels, 1, 1), dim_index=1),
+                 nn.GELU()])
+         layers.append(MaskConv2d(input_channels, output_channels,
+                                  kernel_size=3, padding=3 // 2))
+         self.cnns = nn.ModuleList(layers)
+
+     def forward(self, x, mask):
+         _x = x  # kept as the residual
+         for layer in self.cnns:
+             if isinstance(layer, LayerNorm):
+                 x = x + _x  # residual connection before normalization
+                 x = layer(x)
+                 _x = x
+             elif not isinstance(layer, nn.GELU):
+                 x = layer(x, mask)  # masked convolutions take the pad mask
+             else:
+                 x = layer(x)
+         return _x
+
+
+ class MultiHeadBiaffine(nn.Module):
+     def __init__(self, dim, out=None, n_head=4):
+         super(MultiHeadBiaffine, self).__init__()
+         assert dim % n_head == 0
+         in_head_dim = dim // n_head
+         out = dim if out is None else out
+         assert out % n_head == 0
+         out_head_dim = out // n_head
+         self.n_head = n_head
+         self.W = nn.Parameter(nn.init.xavier_normal_(torch.randn(
+             self.n_head, out_head_dim, in_head_dim, in_head_dim)))
+         self.out_dim = out
+
+     def forward(self, h, v):
+         """
+         :param h: bsz x max_len x dim
+         :param v: bsz x max_len x dim
+         :return: bsz x out_dim x max_len x max_len
+         """
+         bsz, max_len, dim = h.size()
+         h = h.reshape(bsz, max_len, self.n_head, -1)
+         v = v.reshape(bsz, max_len, self.n_head, -1)
+         # b: bsz, l: seq_len (head position), k: seq_len (tail position),
+         # h: head_num, x/y: in_head_dim, d: out_head_dim
+         w = torch.einsum('blhx,hdxy,bkhy->bhdlk', h, self.W, v)
+         # [batch_size, out_dim, seq_len, seq_len]
+         w = w.reshape(bsz, self.out_dim, max_len, max_len)
+         return w
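The einsum in `MultiHeadBiaffine.forward` contracts each (head position, tail position) pair through a per-head bilinear form, producing one score vector per span. A NumPy shape sketch of the same contraction (all sizes are illustrative, not the model's real hyperparameters):

```python
import numpy as np

bsz, max_len, n_head = 2, 5, 4
in_head_dim, out_head_dim = 3, 3
out_dim = n_head * out_head_dim

# Head/tail states already split into per-head chunks, plus the bilinear weights
h = np.random.randn(bsz, max_len, n_head, in_head_dim)
v = np.random.randn(bsz, max_len, n_head, in_head_dim)
W = np.random.randn(n_head, out_head_dim, in_head_dim, in_head_dim)

# Same contraction as the torch.einsum above: one bilinear score per
# (head position l, tail position k) pair within each attention head
w = np.einsum('blhx,hdxy,bkhy->bhdlk', h, W, v)

# Merge the head and per-head-output axes into a single out_dim axis
w = w.reshape(bsz, out_dim, max_len, max_len)
```

Splitting `dim` into `n_head` chunks keeps the parameter count at `n_head * out_head_dim * in_head_dim^2` instead of the `out * dim^2` a single full biaffine would need.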
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb3fe79261f6683f838a0de670ca8425e9536826423cd6cc91c24851ff4e422c
+ size 418892592
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "special_tokens_map_file": "/home/lpc/models/chinese-bert-wwm-ext/special_tokens_map.json", "name_or_path": "/home/lpc/models/chinese-bert-wwm-ext/", "tokenizer_class": "BertTokenizer"}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff