import os.path as osp
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import argparse
import pdb
import json
from model import BertTokenizer
from collections import Counter
from ltp import LTP
from tqdm import tqdm
from src.utils import add_special_token
from functools import reduce
from time import time
from numpy import mean
import math

from src.utils import Loss_log, time_trans
from collections import defaultdict


class cfg():
    def __init__(self):
        self.this_dir = osp.dirname(__file__)
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', ''))

    def get_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument("--data_path", default="huawei", type=str, help="Experiment path")
        parser.add_argument("--freq", default=50, type=int, help="Minimum frequency for a word to be treated as important")
        parser.add_argument("--batch_size", default=100, type=int, help="Batch size for word segmentation")
        parser.add_argument("--seq_data_name", default='Seq_data_large', type=str, help="Name of the seq_data file")
        parser.add_argument("--deal_numeric", default=0, type=int, help="Whether to process numeric data")
        parser.add_argument("--read_cws", default=0, type=int, help="Whether to read a previously segmented (cws) file")
        self.cfg = parser.parse_args()

    def update_train_configs(self):
        self.cfg.data_root = self.data_root
        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)
        return self.cfg


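# Usage sketch (the script file name below is an assumption; the flags are the ones
# defined in cfg.get_args above):
#   python generate_chinese_ref.py --data_path huawei --freq 50 --batch_size 100 \
#       --seq_data_name Seq_data_large --deal_numeric 1 --read_cws 0
# update_train_configs() then returns the parsed args with `data_root` pointing to
# <repo>/data and `data_path` expanded to <repo>/data/huawei.
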
def refresh_data(ref, freq, special_token):
    '''
    Purpose: starting from the hand-defined special tokens, use a minimum occurrence
             frequency to collect additional reference words for the segmenter,
             which serve as the basis for whole word masking (wwm).
    Input:
        freq: minimum occurrence frequency in the (~370k-entry) semantic vocabulary
              (whitespace-separated words)
        special_token: the manually defined special tokens (may overlap with the new words)
    Output:
        add_words: new words selected by the minimum occurrence frequency
    '''
    seq_sub_data = [line.split() for line in ref]
    all_data = []
    for data in seq_sub_data:
        all_data.extend(data)
    sub_word_times = dict(Counter(all_data))
    sub_word_time_order = sorted(sub_word_times.items(), key=lambda x: x[1], reverse=True)

    add_words = []
    for word, count in sub_word_time_order:
        # keep frequent, multi-character, non-numeric words shorter than 20 characters
        if count >= freq and 1 < len(word) < 20 and not word.isdigit():
            add_words.append(word)
    add_words.extend(special_token)

    print(f"[{len(add_words)}] special words will be added with frequency [{freq}]!")
    return add_words


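# A minimal sketch of what refresh_data produces (toy inputs; in the script the real
# `ref` comes from special_vocab.txt and `special_token` from add_special_token):
#   refresh_data(["流量 异常 下降", "流量 异常 上升", "流量 异常 下降"],
#                freq=2, special_token=['[KPI]'])
#   # -> ['流量', '异常', '下降', '[KPI]'] : words seen at least twice, longer than
#   #    one character and non-numeric, followed by the special tokens
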
def cws(seq_data, add_words, batch_size):
    '''
    Purpose: convert every sequence input into its word-segmented form.
    Input:
        seq_data: all sequence inputs, e.g. ['KPI异常下降', 'KPI异常上升']
        add_words: the special words added to the segmenter
        batch_size: number of sentences segmented per batch
    Output:
        seq_data_cws: segmented sequences, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        data_size: number of input sequences (e.g. 2)
        error_data: inputs that failed segmentation
    '''
    print("loading...")
    ltp = LTP("LTP/base2")

    print("begin adding words ...")
    ltp.add_words(words=add_words)
    ltp.to("cuda")
    print(f"{len(add_words)} special words are added!")

    data_size = len(seq_data)
    seq_data_cws = []
    size = int(data_size / batch_size) + 1
    b = 0
    e = b + batch_size

    log = Loss_log()

    with tqdm(total=size) as _tqdm:
        error_data = []
        for _ in range(size):
            output = []
            try:
                _output = ltp.pipeline(seq_data[b:e], tasks=["cws"])
                for data in _output.cws:
                    try:
                        # second-pass segmentation over the tokens of the first pass
                        data_out = ltp.pipeline(data, tasks=["cws"])
                        data_out_ = []
                        for tokens in data_out.cws:
                            data_out_.extend([k.strip() for k in tokens])
                        output.append(data_out_)
                    except Exception:
                        print(f"Second-stage segmentation failed! Range: [{b}]-[{e}]")
                        error_data.append(data)
            except Exception:
                print(f"First-stage segmentation failed! Range: [{b}]-[{e}]")
                error_data.append(f"First-stage segmentation failed! Range: [{b}]-[{e}]")

            seq_data_cws.extend(output)
            b = e
            e += batch_size
            if e >= data_size:
                if b >= data_size:
                    break
                e = data_size
            _tqdm.set_description(f'from {b} to {e}:')
            _tqdm.update(1)

    print(f"Filtered out {data_size - len(seq_data_cws)} sentences")
    return seq_data_cws, data_size, error_data


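# Usage sketch for cws (not executed here: it loads the "LTP/base2" segmentation
# model and moves it to CUDA, so it needs the LTP weights and a GPU):
#   segmented, n, errors = cws(['KPI异常下降', 'KPI异常上升'],
#                              add_words=['KPI', '异常'], batch_size=100)
#   # segmented should look like [['KPI', '异常', '下降'], ['KPI', '异常', '上升']],
#   # as described in the docstring above.
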
def ltp_debug(ltp, op):
    output = []
    for data in op:
        data_out = ltp.pipeline(data, tasks=["cws"])
        data_out_ = []
        for tokens in data_out.cws:
            data_out_.append(tokens[0].strip())
        output.append(data_out_)
    return output


def deal_sub_words(subwords, special_token):
    '''
    Purpose: within each whole word, prefix every non-leading sub-token with '##';
             tokens in special_token should not be masked and are left untouched.
    '''
    for i in range(len(subwords)):
        if i == 0:
            continue
        if subwords[i] in special_token:
            continue
        if subwords[i].startswith("##"):
            continue
        subwords[i] = "##" + subwords[i]
    return subwords


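# Example of the '##' labelling on a single segmented word (pure function):
#   deal_sub_words(['异', '常'], special_token=[])  ->  ['异', '##常']
# Tokens listed in special_token are never prefixed with '##'.
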
def generate_chinese_ref(seq_data_cws, special_token, deal_numeric, kpi_dic):
    '''
    Input:
        seq_data_cws: segmented sequence data, e.g. [['KPI', '异常', '下降'], ['KPI', '异常', '上升']]
        special_token: tokens that must not be masked, e.g.
            ['[SEP]', '[MASK]', '[ALM]', '[KPI]', '[CLS]', '[LOC]', '[EOS]', '[ENT]', '[ATTR]', '[NUM]', '|']
        deal_numeric: whether to also extract KPI/numeric information
        kpi_dic: mapping from KPI name to KPI id
    Output:
        ww_return (whole word return): the labelled chinese ref,
            e.g. [['KPI', '异', '##常', '下', '##降'], ['KPI', '异', '##常', '上', '##升']]
    '''
    data_size = len(seq_data_cws)
    kpi_static_set = set()
    rev_kpi_dic = dict(zip(kpi_dic.values(), kpi_dic.keys()))
    max_len = 0
    sten_that_over_maxl = []
    with tqdm(total=data_size) as _tqdm:
        ww_return = []
        ww_list = []
        kpi_info = []
        not_in_KPI = defaultdict(int)
        for i in range(data_size):
            _tqdm.set_description(f'checking...[{i}/{data_size}] max len: [{max_len}]')
            orig = tokenizer.tokenize(" ".join(seq_data_cws[i]))

            if deal_numeric:
                _kpi_info, kpi_type_list = extract_kpi(orig, kpi_dic, not_in_KPI)
                kpi_info.append(_kpi_info)
                kpi_static_set.update(kpi_type_list)

            sub_total = []
            ww_seq_tmp = []
            ww_tmp = []
            for sub_data in seq_data_cws[i]:
                sub = tokenizer.tokenize(sub_data)
                sub_total.extend(sub)
                ref_token = deal_sub_words(sub, special_token)
                ww_seq_tmp.extend(ref_token)
                ww_tmp.append(ref_token)

            if sub_total != orig:
                print("error in match... ")
                if len(orig) > 512:
                    print("the length is over the max length")
                    pdb.set_trace()

            sz_ww_seq = len(ww_seq_tmp)
            max_len = sz_ww_seq if sz_ww_seq > max_len else max_len
            if sz_ww_seq > 500:
                sten_that_over_maxl.append((ww_seq_tmp, sz_ww_seq))

            assert len(sub_total) == sz_ww_seq
            ww_return.append(ww_seq_tmp)
            ww_list.append(ww_tmp)
            _tqdm.update(1)

    if deal_numeric:
        in_kpi = []
        for key in rev_kpi_dic.keys():
            if key in kpi_static_set:
                in_kpi.append(rev_kpi_dic[key])
        if len(in_kpi) < len(rev_kpi_dic):
            print(f"[{len(in_kpi)}] KPI are covered by data: {in_kpi}")
            print(f"[{len(not_in_KPI)}] KPI could not be matched: {not_in_KPI}")
        else:
            print("all KPI are covered!")
    return ww_return, kpi_info, sten_that_over_maxl


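# Usage sketch (generate_chinese_ref relies on the module-level `tokenizer` created
# in __main__, so it is only meaningful after MacBert and the special tokens are loaded):
#   chinese_ref, kpi_info, too_long = generate_chinese_ref(
#       [['异常', '下降']], special_token, deal_numeric=0, kpi_dic={})
#   # chinese_ref -> [['异', '##常', '下', '##降']] for a character-level BERT tokenizer
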
def extract_num(seq_data_cws):
    '''
    Purpose: extract the numeric values from each sequence and
             filter out sequences with nan values.
    '''
    num_ref = []
    seq_data_cws_new = []
    for j in range(len(seq_data_cws)):
        num_index = [i for i, x in enumerate(seq_data_cws[j]) if x == '[NUM]']
        kpi_score = []
        flag = 1
        for index in num_index:
            try:
                tmp = float(seq_data_cws[j][index + 1])
            except Exception:
                # the token after [NUM] is not a parsable number
                flag = 0
                continue
            if math.isnan(tmp):
                flag = 0
            else:
                kpi_score.append(tmp)

        if len(num_index) > 0:
            # remove the numeric tokens themselves, keeping the [NUM] placeholders
            for index in reversed(num_index):
                seq_data_cws[j].pop(index + 1)
            if flag == 1:
                num_ref.append(kpi_score)
                seq_data_cws_new.append(seq_data_cws[j])
    return seq_data_cws_new, num_ref


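# Example: the value token following each '[NUM]' placeholder is pulled out,
# and sequences whose values are missing or nan are dropped:
#   extract_num([['流量', '[NUM]', '0.5'], ['时延', '[NUM]', 'nan']])
#   # -> ([['流量', '[NUM]']], [[0.5]])
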
def extract_kpi(token_data, kpi_dic, not_in_KPI):
    '''
    Purpose: extract the index range of each [KPI] name and the index of its [NUM] token.
    Output format: one record per KPI, [name_start, name_end, num_index, kpi_id]
    '''
    kpi_and_num_info = []
    kpi_type = []
    kpi_index = [i for i, x in enumerate(token_data) if x.lower() == '[kpi]']
    num_index = [i for i, x in enumerate(token_data) if x.lower() == '[num]']
    sz = len(kpi_index)
    assert sz == len(num_index)
    for i in range(sz):
        # KPI name: the tokens between the [KPI] marker and the separator before [NUM]
        kpi_name = ''.join(token_data[kpi_index[i] + 1: num_index[i] - 1])
        kpi_name_clear = kpi_name.replace('##', '')

        if kpi_name in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name])
        elif kpi_name_clear in kpi_dic:
            kpi_id = int(kpi_dic[kpi_name_clear])
        elif kpi_name_clear in not_in_KPI:
            kpi_id = -1
            not_in_KPI[kpi_name_clear] += 1
        else:
            not_in_KPI[kpi_name_clear] += 1
            kpi_id = -1

        kpi_info = [kpi_index[i] + 1, num_index[i] - 2, num_index[i], kpi_id]
        kpi_and_num_info.append(kpi_info)
        kpi_type.append(kpi_id)

    return kpi_and_num_info, kpi_type


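# Example, assuming the tokenized layout "[KPI] <name tokens> <separator> [NUM]"
# (the '|' separator between the KPI name and [NUM] is an assumption here):
#   extract_kpi(['[KPI]', '流', '##量', '|', '[NUM]'], {'流量': 7}, defaultdict(int))
#   # -> ([[1, 2, 4, 7]], [7]) : [name_start, name_end, num_index, kpi_id] per KPI
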
def kpi_combine(kpi_info, num_ref):
    '''Append each sequence's numeric values to its corresponding KPI span records.'''
    sz = len(kpi_info)
    assert sz == len(num_ref)
    for i in range(sz):
        for j in range(len(kpi_info[i])):
            kpi_info[i][j].append(num_ref[i][j])
    return kpi_info


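# Example: each KPI record gains its numeric value as a fifth element:
#   kpi_combine([[[1, 2, 4, 7]]], [[0.5]])  ->  [[[1, 2, 4, 7, 0.5]]]
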
def kpi_lower_update(kpi_dic):
    '''Normalize KPI names: lower-case them and strip all whitespace.'''
    new_dic = {}
    for key in kpi_dic:
        kk = key.lower().split()
        kk = ''.join(kk).strip()
        new_dic[kk] = kpi_dic[key]
    return new_dic


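# Example: KPI names are lower-cased and whitespace is removed before lookup:
#   kpi_lower_update({'Traffic Volume': 3})  ->  {'trafficvolume': 3}
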
if __name__ == '__main__':
    '''
    Purpose: build the chinese ref file and refresh the training/test files
             (sequence text data only).
    '''
    cfg = cfg()
    cfg.get_args()
    cfgs = cfg.update_train_configs()

    domain_file_path = osp.join(cfgs.data_path, 'special_vocab.txt')
    with open(domain_file_path, encoding="utf-8") as f:
        ref = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
    tokenizer = BertTokenizer.from_pretrained(osp.join(cfgs.data_root, 'transformer', 'MacBert'), do_lower_case=True)
    seq_data_name = cfgs.seq_data_name
    with open(osp.join(cfgs.data_path, f'{seq_data_name}.json'), "r") as fp:
        seq_data = json.load(fp)
    kpi_dic_name = 'kpi2id'
    with open(osp.join(cfgs.data_path, f'{kpi_dic_name}.json'), "r") as fp:
        kpi_dic = json.load(fp)
    kpi_dic = kpi_lower_update(kpi_dic)

    random.shuffle(seq_data)

    print(f"tokenizer size before: {len(tokenizer)}")
    tokenizer, special_token, norm_token = add_special_token(tokenizer)
    special_token = special_token + norm_token
    print(f"tokenizer size after: {len(tokenizer)}")

    print('------------------------ refresh data --------------------------------')
    add_words = refresh_data(ref, cfgs.freq, special_token)

    if not cfgs.read_cws:
        print('------------------------ cws ----------------------------------')
        seq_data_cws, data_size, error_data = cws(seq_data, add_words, cfgs.batch_size)
        print(f'batch size is {cfgs.batch_size}')
        if len(error_data) > 0:
            with open(osp.join(cfgs.data_path, f'{seq_data_name}_error.json'), "w") as fp:
                json.dump(error_data, fp, ensure_ascii=False)
        save_path_cws_orig = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data! saving...")
        with open(save_path_cws_orig, 'w') as fp:
            json.dump(seq_data_cws, fp, ensure_ascii=False)
    else:
        print('------------------------ read ----------------------------------')
        save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws_orig.json')
        print("get the new training data!")
        with open(save_path_cws, 'r') as fp:
            seq_data_cws = json.load(fp)
        data_size = len(seq_data_cws)

    sz_orig = len(seq_data_cws)
    if cfgs.deal_numeric:
        seq_data_cws, num_ref = extract_num(seq_data_cws)
        print(f"Filtered out {sz_orig - len(seq_data_cws)} invalid sentences")
        data_size = len(seq_data_cws)

    print('---------------------- generate chinese ref ------------------------------')
    chinese_ref, kpi_info, sten_that_over_maxl = generate_chinese_ref(seq_data_cws, special_token, cfgs.deal_numeric, kpi_dic)

    if len(sten_that_over_maxl) > 0:
        print(f"{len(sten_that_over_maxl)} over the 500 len!")
        save_path_max = osp.join(cfgs.data_path, f'{seq_data_name}_max_len_500.json')
        with open(save_path_max, 'w') as fp:
            json.dump(sten_that_over_maxl, fp, ensure_ascii=False)

    if cfgs.deal_numeric:
        print("KPI info combine")
        kpi_ref = kpi_combine(kpi_info, num_ref)

    print('------------------------- match finished ------------------------------')

    save_path_ref = osp.join(cfgs.data_path, f'{seq_data_name}_chinese_ref.json')
    with open(save_path_ref, 'w') as fp:
        json.dump(chinese_ref, fp, ensure_ascii=False)
    print("save chinese_ref done!")

    seq_data_cws_output = []
    for i in range(data_size):
        seq = " ".join(seq_data_cws[i])
        seq_data_cws_output.append(seq)

    save_path_cws = osp.join(cfgs.data_path, f'{seq_data_name}_cws.json')
    print("get the new training data!")
    with open(save_path_cws, 'w') as fp:
        json.dump(seq_data_cws_output, fp, ensure_ascii=False)
    print("save seq_data_cws done!")

    if cfgs.deal_numeric:
        kpi_ref_path = osp.join(cfgs.data_path, f'{seq_data_name}_kpi_ref.json')
        with open(kpi_ref_path, 'w') as fp:
            json.dump(kpi_ref, fp, ensure_ascii=False)
        print("save num and kpi done!")