twnlp committed on
Commit
4526d38
·
verified ·
1 Parent(s): 06b578f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +265 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import gradio as gr
3
+ import re
4
+ import logging
5
+ from datetime import datetime
6
+ import json
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from typing import List
10
+
11
# ==================== Logging configuration ====================
# Log simultaneously to a file (persistent correction history) and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # utf-8 so Chinese log messages are written to the file correctly
        logging.FileHandler('text_correction.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
21
+
22
# ==================== Model loading ====================
# Loaded once at import time; first run downloads the checkpoint from the
# Hugging Face hub, so startup can take a while.
logger.info("正在加载模型,请稍候...")
model_name = "twnlp/ChineseErrorCorrector3-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # halves memory; supported on modern CPUs
    device_map="cpu",
    low_cpu_mem_usage=True,  # reduce peak memory usage while loading
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info("模型加载完成 ✓")
33
+
34
# ==================== Paragraph splitting ====================
# Characters normalized to a plain space before splitting: BOM, ideographic /
# en / non-breaking spaces, stray control chars, underscores, zero-width joiners.
blanks = ["\ufeff", "\u3000", "\u2002", "\xa0", "\x07", "\x0b", "\x0c", "_", "_", "\u200d", "\u200c"]

# Translation table built once at import: one C-level pass over the text
# instead of len(blanks) chained .replace() calls.
_BLANKS_TABLE = str.maketrans({b: " " for b in blanks})

def replace_blanks(text):
    """Return *text* with every character listed in ``blanks`` replaced by a space."""
    return text.translate(_BLANKS_TABLE)
41
+
42
def split_sentence(document_input: str, min_len: int = 16, max_len: int = 126):
    """Split a paragraph into sentences at terminal punctuation.

    Returns a list of ``[offset, sentence]`` pairs where *offset* is the
    character position of the sentence within *document_input*. Sentences
    longer than *max_len* are further split by ``split_subsentence``.
    On any failure, or if characters would be lost by splitting, the whole
    input is returned as a single ``[[0, document_input]]`` segment.
    """
    sent_list = []
    try:
        # Runs of >=2 consecutive punctuation marks signal unusual text; in
        # that case split on the punctuation runs themselves.
        punctuation_flag = re.search(
            r"""[^\w《》""【】\[\]<>()()〔〕「」『』〖〗〈〉﹛﹜{}×—-\-%%¥$□℃\xa0\u3000\r\n \t]{2,}""",
            document_input
        )

        if punctuation_flag:
            document = re.sub(
                r"""(?P<quotation_mark>([^\w《》""【】\[\]<>()()〔〕「」『』〖〗〈〉﹛﹜{}×—-\-%%¥$□℃\xa0\u3000\r\n \t]{2,}))""",
                r'\g<quotation_mark>\n', document_input
            )
        else:
            # Normal text: break after sentence-final punctuation not followed
            # by a quote, then after punctuation+closing-quote pairs.
            document = re.sub(
                r"""(?P<quotation_mark>([。?!…?!|](?!["'"\'])))""",
                r'\g<quotation_mark>\n', document_input
            )
            document = re.sub(
                r"""(?P<quotation_mark>(([。?!!?|]|…{1,2})["'"\']))""",
                r'\g<quotation_mark>\n', document
            )

        for sent in document.split('\n'):
            sent = sent.replace('|', '')
            if not sent:
                continue
            if len(sent) > max_len:
                sent_list.extend(split_subsentence(sent, min_len=min_len))
            else:
                sent_list.append(sent)
    except Exception:
        # Narrowed from a bare except; regex or splitting failures fall back
        # to treating the whole input as one sentence.
        sent_list = [document_input]

    # Offsets are only meaningful if no characters were lost while splitting
    # (e.g. '|' stripped above). Fall back instead of crashing — this was
    # previously an assert, which both crashed on '|' input and vanished
    # under `python -O`.
    if sum(len(s) for s in sent_list) != len(document_input):
        sent_list = [document_input]

    offset = 0
    res = []
    for sent in sent_list:
        res.append([offset, sent])
        offset += len(sent)
    return res
85
+
86
# Clause-level split markers: commas, semicolons and closing parentheses,
# in both full-width and ASCII forms.
sub_split_flag = [',', ',', ';', ';', ')', ')']

def split_subsentence(sentence, min_len=16):
    """Generator: split an over-long sentence into clauses at ``sub_split_flag``
    punctuation, keeping each yielded piece at least ``min_len`` characters
    when possible.

    NOTE(review): indentation reconstructed — the original layout was lost;
    the ``elif`` is assumed to pair with ``if c in sub_split_flag`` so that a
    sentence with no trailing delimiter still yields its final clause.
    """
    sent = ''
    for i, c in enumerate(sentence):
        sent += c
        if c in sub_split_flag:
            # Delimiter is the penultimate character: attach the final
            # character to this clause and stop.
            if i == len(sentence) - 2:
                yield sent[:-1] + c + sentence[-1]
                break
            # Look ahead up to 5 characters: postpone the cut when another
            # ASCII comma (or the sentence end) is imminent.
            flag = True
            for j in range(i + 1, min(len(sentence) - 1, i + 6)):
                if sentence[j] == ',' or j == len(sentence) - 1:
                    flag = False
            if (flag and len(sent) >= min_len) or i == len(sentence) - 1:
                yield sent[:-1] + c  # equivalent to ``sent`` (it ends with c)
                sent = ''
        elif i == len(sentence) - 1:
            # No trailing delimiter: emit whatever is left.
            yield sent
105
+
106
def split_paragraph_lst(paragraph_lst: List[str], min_len: int = 16, max_len: int = 126):
    """Split a list of paragraphs into ``[offset, sentence]`` pairs.

    Blanks are normalized first; paragraphs are broken on newlines and '|'
    separators, and pieces longer than *max_len* are further split with
    ``split_sentence``. NOTE(review): offsets index the cleaned concatenation
    of the pieces — characters removed here ('\\r', '\\n', '|') are not
    accounted for against the raw input.
    """
    # Normalize blanks, drop carriage returns, then split on '\n' and '|'.
    pieces = []
    for paragraph in paragraph_lst:
        cleaned = replace_blanks(paragraph).replace('\r', '')
        for line in cleaned.split('\n'):
            pieces.extend(line.split('|'))

    # Offset of each piece within the concatenation of all pieces.
    offsets = []
    pos = 0
    for piece in pieces:
        offsets.append(pos)
        pos += len(piece)

    res = []
    for base, piece in zip(offsets, pieces):
        piece = piece.replace('|', '')
        if not piece.strip():
            continue
        if len(piece) > max_len:
            # Over-long piece: delegate to sentence-level splitting and map
            # the sub-offsets back onto this piece's base offset.
            for sub_off, sub in split_sentence(piece, min_len=min_len, max_len=max_len):
                if sub.strip():
                    res.append([base + sub_off, sub])
        else:
            res.append([base, piece])
    return res
135
+
136
# ==================== Correction core ====================
def clean_model_output(text):
    """Strip any ``<think>...</think>`` reasoning block (possibly spanning
    multiple lines) and surrounding whitespace from raw model output."""
    without_think = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    return without_think.strip()
140
+
141
def find_diff_segments(source, target):
    """Locate the single contiguous span where *source* and *target* differ.

    Returns ``[]`` when the strings are identical; otherwise a one-element
    list with keys ``original``, ``corrected``, ``position`` (index of the
    span in *source*) and ``type`` ("replace" / "delete" / "insert").
    """
    if source == target:
        return []

    src_len, tgt_len = len(source), len(target)

    # Longest common prefix.
    prefix = 0
    limit = min(src_len, tgt_len)
    while prefix < limit and source[prefix] == target[prefix]:
        prefix += 1

    # Longest common suffix that does not overlap the prefix.
    suffix = 0
    while (suffix < src_len - prefix and suffix < tgt_len - prefix
           and source[src_len - 1 - suffix] == target[tgt_len - 1 - suffix]):
        suffix += 1

    src_mid = source[prefix:src_len - suffix]
    tgt_mid = target[prefix:tgt_len - suffix]
    if not src_mid and not tgt_mid:
        return []

    if src_mid and tgt_mid:
        kind = "replace"
    elif src_mid:
        kind = "delete"
    else:
        kind = "insert"
    return [{
        "original": src_mid,
        "corrected": tgt_mid,
        "position": prefix,
        "type": kind,
    }]
162
+
163
def correct_single_sentence(sentence: str) -> str:
    """Run the correction model on one sentence and return the corrected text.

    Uses the module-level ``model`` and ``tokenizer``; generation is greedy
    (``do_sample=False``) so results are deterministic.
    """
    prompt = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:"
    chat = [{"role": "user", "content": prompt + sentence}]
    rendered = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    inputs = tokenizer([rendered], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    # Keep only the newly generated tokens (drop the echoed prompt ids).
    completions = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(inputs.input_ids, outputs)
    ]
    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
    return clean_model_output(decoded)
178
+
179
def text_correction(input_text):
    """Gradio callback: correct *input_text* sentence by sentence.

    Returns a ``(markdown_summary, json_string)`` tuple for the two output
    widgets. All exceptions are caught at this UI boundary, logged with a
    traceback, and reported to the user as text.
    """
    logger.info("=" * 60)
    logger.info(f"[用户输入] {input_text}")

    # Guard against missing input as well as empty/whitespace-only input
    # (previously `None.strip()` would have raised here).
    if not input_text or not input_text.strip():
        return "请输入需要纠错的文本", ""

    try:
        start_time = datetime.now()

        # Split the paragraph into model-sized sentences with their offsets.
        segments = split_paragraph_lst([input_text])
        logger.info(f"[分句结果] 共 {len(segments)} 个子句")

        all_errors = {}
        corrected_parts = []
        error_count = 0

        for offset, sent in segments:
            logger.info(f" [子句] offset={offset} | {sent}")
            corrected = correct_single_sentence(sent)
            logger.info(f" [纠正] {corrected}")
            corrected_parts.append(corrected)

            # Collect the diff between the original and corrected sentence.
            diffs = find_diff_segments(sent, corrected)
            for diff in diffs:
                error_count += 1
                diff["position"] = offset + diff["position"]  # map back to full-text position
                all_errors[f"error_{error_count}"] = diff

        corrected_full = "".join(corrected_parts)
        duration = (datetime.now() - start_time).total_seconds()
        logger.info(f"[总耗时] {duration:.2f} 秒")

        result = {"tgt": corrected_full, "des": all_errors}
        result_json = json.dumps(result, ensure_ascii=False, indent=2)

        if all_errors:
            error_details = "**发现的错误:**\n\n"
            # Iterate values only — the synthetic "error_N" keys are not shown.
            for error in all_errors.values():
                error_details += f"- 位置 {error['position']}: `{error['original']}` → `{error['corrected']}`\n"
        else:
            error_details = "✅ 未发现错误,句子正确!"

        output_text = f"**原文:**\n{input_text}\n\n**纠正后:**\n{corrected_full}\n\n{error_details}"
        logger.info("[处理完成] ✓")

        return output_text, result_json

    except Exception as e:
        logger.error(f"[错误] {str(e)}", exc_info=True)
        return f"错误: {str(e)}", ""
232
+
233
# ==================== Gradio UI ====================
with gr.Blocks(title="ChineseErrorCorrector3") as demo:
    gr.Markdown("# 🔍 ChineseErrorCorrector3")
    gr.Markdown("支持长段落输入,自动分句后逐句纠错(本地 CPU 推理,句子越多耗时越长)")

    with gr.Row():
        with gr.Column():
            # Input panel: multi-line textbox plus submit button.
            input_text = gr.Textbox(
                label="输入文本(支持长段落)",
                placeholder="例如:他每天都去跑部锻炼身体。对待每一项工作都要一丝不够。",
                lines=5
            )
            submit_btn = gr.Button("开始纠错", variant="primary")
        with gr.Column():
            # Result panel: markdown rendering of the correction summary.
            output_display = gr.Markdown(label="纠错结果")

    with gr.Row():
        # Machine-readable mirror of the correction result.
        result_json = gr.Textbox(label="JSON 格式输出", lines=10, interactive=False)

    gr.Examples(
        examples=[
            ["我的名字较做小明"],
            ["他每天都去跑部锻炼身体"]
        ],
        inputs=input_text
    )

    # Both the button click and pressing Enter in the textbox run correction.
    submit_btn.click(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])
    input_text.submit(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])

if __name__ == "__main__":
    logger.info("启动中文文本纠错助手...")
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ openai
+ torch
3
+ transformers