Macropodus committed on
Commit
b897bdf
·
verified ·
1 Parent(s): f13d00d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -73
app.py CHANGED
@@ -6,100 +6,124 @@
6
 
7
 
8
  import traceback
 
9
  import time
10
  import sys
11
  import os
 
 
 
12
  os.environ["USE_TORCH"] = "1"
13
 
14
# --- Legacy standalone demo: load a MacBERT CSC model and correct sample texts ---
from transformers import BertConfig, BertTokenizer, BertForMaskedLM

import gradio as gr
import torch


# pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
# pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
# pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
device = torch.device("cpu")
# device = torch.device("cuda")
max_len = 128

print("load model, please wait a few minute!")
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
bert_config = BertConfig.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
model.to(device)
print("load model success!")

# Sample sentences with deliberate spelling errors for a smoke-test run.
texts = [
    "机七学习是人工智能领遇最能体现智能的一个分知",
    "我是练习时长两念半的鸽仁练习生蔡徐坤",
]
# Batch padding length: longest text plus [CLS]/[SEP], capped at max_len.
len_mid = min(max_len, max(len(t) + 2 for t in texts))

with torch.no_grad():
    outputs = model(**tokenizer(texts, padding=True, max_length=len_mid,
                                return_tensors="pt").to(device))


def get_errors(source, target):
    """Minimal char-diff: one [src_char, tgt_char, idx] entry per mismatched position."""
    return [[s, t, i] for i, (s, t) in enumerate(zip(source, target)) if s != t]


result = []
for logits, source in zip(outputs.logits, texts):
    pred_ids = torch.argmax(logits, dim=-1)
    # Drop [CLS]/[SEP], then remove the spaces the decoder inserts between tokens.
    decoded = tokenizer.decode(pred_ids[1:-1], skip_special_tokens=False)
    target = decoded.replace(" ", "")[:len(source)]
    errors = get_errors(source, target)
    print(source, " => ", target, errors)
    result.append([target, errors])
print(result)
 
 
 
 
 
 
65
 
66
 
67
def macro_correct(text):
    """Correct a single Chinese sentence with the MaskedLM model.

    Returns the corrected string followed by a stringified error list,
    where each error entry is [src_char, tgt_char, position].
    """
    with torch.no_grad():
        encoded = tokenizer([text], padding=True, max_length=max_len,
                            return_tensors="pt").to(device)
        outputs = model(**encoded)

    def to_highlight(corrected_sent, errs):
        # NOTE(review): never called; err[3] would raise IndexError because
        # get_errors() below emits 3-element entries — kept only for parity.
        output = [{"entity": "纠错", "word": err[1], "start": err[2], "end": err[3]} for i, err in
                  enumerate(errs)]
        return {"text": corrected_sent, "entities": output}

    def get_errors(source, target):
        """Minimal char-diff: [src_char, tgt_char, idx] per mismatched position."""
        return [[s, t, i] for i, (s, t) in enumerate(zip(source, target)) if s != t]

    result = []
    for logits, source in zip(outputs.logits, [text]):
        pred_ids = torch.argmax(logits, dim=-1)
        # Drop [CLS]/[SEP]; the decoder joins tokens with spaces, so strip them.
        decoded = tokenizer.decode(pred_ids[1:-1], skip_special_tokens=False)
        target = decoded.replace(" ", "")[:len(source)]
        errors = get_errors(source, target)
        print(source, " => ", target, errors)
        result.append([target, errors])
    # print(result)
    return target + " " + str(errors)
 
 
 
 
 
 
 
 
 
97
 
98
 
99
  if __name__ == '__main__':
100
  print(macro_correct('少先队员因该为老人让坐'))
101
 
102
-
103
  examples = [
104
  "机七学习是人工智能领遇最能体现智能的一个分知",
105
  "我是练习时长两念半的鸽仁练习生蔡徐坤",
@@ -118,5 +142,7 @@ if __name__ == '__main__':
118
  examples=examples
119
  ).launch()
120
  # ).launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)
121
-
 
 
122
 
 
6
 
7
 
8
  import traceback
9
+ import copy
10
  import time
11
  import sys
12
  import os
13
+ import re
14
+ os.environ["MACRO_CORRECT_FLAG_CSC_TOKEN"] = "1"
15
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
16
  os.environ["USE_TORCH"] = "1"
17
 
18
+ from macro_correct.pytorch_textcorrection.tcTools import preprocess_same_with_training
19
+ from macro_correct.pytorch_textcorrection.tcTools import get_errors_for_difflib
20
+ from macro_correct.pytorch_textcorrection.tcTools import cut_sent_by_maxlen
21
+ from macro_correct.pytorch_textcorrection.tcTools import count_flag_zh
22
+ from macro_correct import correct_basic
23
+ from macro_correct import correct_long
24
+ from macro_correct import correct
25
  import gradio as gr
 
 
26
 
27
+ # pyinstaller -F xxxx.py
28
 
29
  # pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
30
+ # pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
31
  # pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
32
  # pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
33
  pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
34
  # pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
35
+ # device = torch.device("cpu")
36
  # device = torch.device("cuda")
37
+
38
+
39
def cut_sent_by_stay_and_maxlen(text, max_len=126, return_length=True):
    """Split text into sentences, keeping trailing punctuation; hard-cut over-long pieces.

    Args:
        text: str, input text.
        max_len: int, max length used during training; longer pieces are force-cut.
        return_length: bool, whether to also return [start, end) offsets into `text`.
    Returns:
        text_cut: list of sentence strings, or (text_cut, text_length_s) when
        return_length is True.
    """
    # Split on closing punctuation; the delimiter characters themselves are
    # re-attached to each piece below so the original text is preserved.
    text_sp = re.split(r"[》)!?。…”;;!?\n]+", text)
    conn_symbol = "!?。…”;;!?》)\n"
    text_length_s = []
    text_cut = []
    last_index = len(text) - 1
    pos = 0  # global cursor into `text`
    for seg in text_sp:
        piece = seg
        start = pos
        pos += len(seg)
        # Re-attach the run of delimiter characters that followed this segment.
        while pos <= last_index and text[pos] in conn_symbol:
            piece += text[pos]
            pos += 1
        if not piece:
            continue
        if len(piece) > max_len:
            # Punctuation alone could not bound the piece: force-cut it.
            # BUGFIX: the original passed the WHOLE `text` here instead of the
            # over-long piece, re-cutting (and duplicating) the entire input.
            piece_cut, piece_lengths = cut_sent_by_maxlen(
                text=piece, max_len=max_len, return_length=True)
            # NOTE(review): offsets from cut_sent_by_maxlen are relative to
            # `piece`, not to `text` — confirm callers only rely on lengths.
            text_length_s.extend(piece_lengths)
            text_cut.extend(piece_cut)
        else:
            text_length_s.append([start, pos])
            text_cut.append(piece)
    if return_length:
        return text_cut, text_length_s
    return text_cut
81
 
82
 
83
def macro_correct(text):
    """Sentence-split `text`, run CSC correction on each piece, and render a report.

    Returns one string: the corrected text, a '#' separator line, then one
    key/value block per corrected sentence (dict entries from the corrector).
    """
    print(text)
    sentences, _spans = cut_sent_by_stay_and_maxlen(text, return_length=True)
    corrected_parts = []
    details = []
    for sent in sentences:
        print(sent)
        sent_norm = preprocess_same_with_training(sent)
        csc_res = correct_long(sent_norm, num_rethink=1, flag_cut=True, limit_length_char=1)
        print(csc_res)
        # Traditional/simplified (繁简) normalization differences count as errors too.
        if sent != sent_norm:
            _corrected, diff_errors = get_errors_for_difflib(sent_norm, sent)
            zh_errors = [err + [1] for err in diff_errors
                         if count_flag_zh(err[0]) and count_flag_zh(err[1])]
            if zh_errors:
                if csc_res:
                    csc_res[0]["errors"] += zh_errors
                else:
                    csc_res = [{"source": sent, "target": sent_norm, "errors": zh_errors}]
        # Errors found by the corrector itself.
        if csc_res:
            details.extend(csc_res)
            # NOTE(review): only the first segment's target is appended — confirm
            # correct_long returns a single segment per input here.
            corrected_parts.append(csc_res[0].get("target"))
        else:
            details.append({})
            corrected_parts.append(sent)
    report = "".join(corrected_parts) + "\n" + "#" * 32 + "\n"
    for pos, entry in enumerate(details):
        if not entry:
            continue
        for key, val in entry.items():
            if key == "index":
                report += f"idx: {str(pos + 1)}\n"
            else:
                report += f"{str(key).strip()}: {str(val).strip()}\n"
        report += "\n"
    return report
122
 
123
 
124
  if __name__ == '__main__':
125
  print(macro_correct('少先队员因该为老人让坐'))
126
 
 
127
  examples = [
128
  "机七学习是人工智能领遇最能体现智能的一个分知",
129
  "我是练习时长两念半的鸽仁练习生蔡徐坤",
 
142
  examples=examples
143
  ).launch()
144
  # ).launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)
145
+
146
+
147
+
148