File size: 9,717 Bytes
4526d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# -*- coding: utf-8 -*-
import difflib
import json
import logging
import re
from datetime import datetime
from typing import List

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ==================== Logging configuration ====================
# Configure the root logger at INFO level. Each record is sent to two
# handlers: a UTF-8 encoded file ('text_correction.log') and a stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('text_correction.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
# Module-level logger used by the functions defined later in this file.
logger = logging.getLogger(__name__)

# ==================== Model loading ====================
# NOTE: the model and tokenizer are loaded at import time, so importing this
# module does not finish until both from_pretrained() calls return.
logger.info("正在加载模型,请稍候...")
model_name = "twnlp/ChineseErrorCorrector3-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,      # halves memory (original author's note: supported by all modern CPUs)
    device_map="cpu",
    low_cpu_mem_usage=True,           # reduce peak memory usage while loading
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info("模型加载完成 ✓")

# ==================== Paragraph splitting ====================
# Characters that are normalised to a plain ASCII space before splitting.
blanks = ["\ufeff", "\u3000", "\u2002", "\xa0", "\x07", "\x0b", "\x0c", "_", "_", "\u200d", "\u200c"]


def replace_blanks(text):
    """Return *text* with every character listed in ``blanks`` replaced by one space.

    Each entry of ``blanks`` is a single character, so the substitution is
    one-for-one and the length of the text does not change.
    """
    # One translation pass instead of one ``str.replace`` call per entry.
    to_space = str.maketrans(dict.fromkeys(blanks, " "))
    return text.translate(to_space)

def split_sentence(document_input: str, min_len: int = 16, max_len: int = 126):
    """Split a document into sentences and pair each one with its start offset.

    Splitting strategy:

    * If the text contains a run of two or more consecutive punctuation
      characters (anything that is not a word character, a listed bracket or
      symbol, or whitespace), the text is cut after every such run.
    * Otherwise it is cut after each sentence-ending mark, keeping a closing
      quote attached to the mark that precedes it.

    Pieces longer than ``max_len`` are cut further by ``split_subsentence``.

    Args:
        document_input: Text to split.
        min_len: Minimum clause length forwarded to ``split_subsentence``.
        max_len: Pieces longer than this are handed to ``split_subsentence``.

    Returns:
        A list of ``[offset, sentence]`` pairs, where ``offset`` is the sum of
        the lengths of the preceding sentences. The sentences always
        concatenate back to ``document_input``: if splitting fails, or would
        drop characters (the input contains ``'\\n'`` or ``'|'``), the whole
        input is returned as a single piece.
    """
    log = logging.getLogger(__name__)

    # The hyphen is escaped on purpose: written as "—-\-" it formed a reversed
    # character range, which made the pattern fail to compile on every call.
    punct_run = r"""[^\w《》""【】\[\]<>()()〔〕「」『』〖〗〈〉﹛﹜{}×—\-%%¥$□℃\xa0\u3000\r\n \t]{2,}"""

    sent_list = []
    try:
        if re.search(punct_run, document_input):
            # Break the line after every run of consecutive punctuation.
            document = re.sub(punct_run, r'\g<0>\n', document_input)
        else:
            # Break after a sentence-ending mark that is not followed by a quote ...
            document = re.sub(
                r"""(?P<quotation_mark>([。?!…?!|](?!["'"\'])))""",
                r'\g<quotation_mark>\n', document_input
            )
            # ... and after "mark + closing quote" so the quote stays attached.
            document = re.sub(
                r"""(?P<quotation_mark>(([。?!!?|]|…{1,2})["'"\']))""",
                r'\g<quotation_mark>\n', document
            )

        for sent in document.split('\n'):
            sent = sent.replace('|', '')
            if not sent:
                continue
            if len(sent) > max_len:
                sent_list.extend(split_subsentence(sent, min_len=min_len))
            else:
                sent_list.append(sent)
    except Exception:
        # Best effort: keep the document as one piece, but leave a trace
        # instead of swallowing the error silently.
        log.warning("split_sentence failed; returning the input unsplit", exc_info=True)
        sent_list = [document_input]

    # The offsets below are only meaningful when the pieces tile the input
    # exactly. An explicit check is used because ``assert`` is removed under -O.
    if sum(len(s) for s in sent_list) != len(document_input):
        log.warning("split_sentence dropped characters; returning the input unsplit")
        sent_list = [document_input]

    res = []
    offset = 0
    for sent in sent_list:
        res.append([offset, sent])
        offset += len(sent)
    return res

# Marks after which an over-long sentence may be cut into clauses.
sub_split_flag = [',', ',', ';', ';', ')', ')']


def split_subsentence(sentence, min_len=16):
    """Lazily cut *sentence* into clauses at the marks in ``sub_split_flag``.

    A clause is emitted at a mark when it is at least ``min_len`` characters
    long and no ``','`` occurs in the (up to five) characters that follow,
    not counting the final character of the sentence. When only one character
    remains after a mark, it is kept attached to the current clause. The
    yielded clauses concatenate back to *sentence*.

    Args:
        sentence: Text to cut.
        min_len: Minimum length of a clause emitted before the end.

    Yields:
        Consecutive, non-overlapping slices of *sentence*.
    """
    last = len(sentence) - 1
    start = 0  # index at which the clause currently being built begins
    for idx, ch in enumerate(sentence):
        if ch not in sub_split_flag:
            if idx == last:
                yield sentence[start:]
            continue

        if idx == last - 1:
            # A single character follows the mark: emit both together and stop.
            yield sentence[start:]
            return

        comma_ahead = ',' in sentence[idx + 1:min(last, idx + 6)]
        long_enough = idx + 1 - start >= min_len
        if idx == last or (long_enough and not comma_ahead):
            yield sentence[start:idx + 1]
            start = idx + 1

def split_paragraph_lst(paragraph_lst: List[str], min_len: int = 16, max_len: int = 126):
    """Split paragraphs into sentence-sized segments with their text offsets.

    Every paragraph is normalised with ``replace_blanks`` and cut at line
    breaks and ``'|'``; ``'\\r'`` characters are removed. Whitespace-only
    pieces are skipped, and pieces longer than ``max_len`` are split further
    by ``split_sentence``.

    Args:
        paragraph_lst: Paragraphs to split.
        min_len: Minimum clause length forwarded to ``split_sentence``.
        max_len: Pieces longer than this are split further.

    Returns:
        A list of ``[offset, segment]`` pairs. ``offset`` is the index of the
        segment's first character in the concatenation of ``paragraph_lst``.
        The dropped ``'\\n'`` and ``'|'`` delimiters are counted, so the
        offset points at the segment in the original text. (Previously the
        delimiters were not counted and offsets drifted by one per delimiter.)
        Offsets assume ``replace_blanks`` keeps the text length unchanged;
        a stray ``'\\r'`` in the middle of a line still shifts positions
        inside that segment.
    """
    res = []
    cursor = 0  # position in the concatenation of all paragraphs
    for paragraph in paragraph_lst:
        text = replace_blanks(paragraph)
        # With a capturing group, re.split returns text pieces at even
        # indices and the delimiters themselves at odd indices.
        for index, chunk in enumerate(re.split(r'([\n|])', text)):
            chunk_start = cursor
            cursor += len(chunk)
            if index % 2:
                continue  # a delimiter: counted in the offset, never emitted

            sent = chunk.replace('\r', '')
            if not sent.strip():
                continue
            if len(sent) > max_len:
                for offset_subsent, subsent in split_sentence(sent, min_len=min_len, max_len=max_len):
                    if subsent.strip():
                        res.append([chunk_start + offset_subsent, subsent])
            else:
                res.append([chunk_start, sent])
    return res

# ==================== Correction core ====================
def clean_model_output(text):
    """Remove every ``<think>...</think>`` block from *text* and trim the result.

    The match is non-greedy and spans line breaks; an opening tag without a
    closing tag is left untouched.
    """
    think_block = re.compile(r'<think>.*?</think>', re.DOTALL)
    without_reasoning = think_block.sub('', text)
    return without_reasoning.strip()

def find_diff_segments(source, target):
    """Return the character-level edits that turn *source* into *target*.

    Each separate change is reported as its own segment, so two unrelated
    corrections in one sentence are no longer merged into a single span that
    also contains the unchanged text between them.

    Args:
        source: Original text.
        target: Corrected text.

    Returns:
        A list of dicts, ordered by position, each with the keys:
        ``original`` (replaced or deleted text, ``""`` for an insertion),
        ``corrected`` (new text, ``""`` for a deletion),
        ``position`` (index in *source* where the edit starts) and
        ``type`` (``"replace"``, ``"delete"`` or ``"insert"``).
        The list is empty when the texts are equal.
    """
    if source == target:
        return []

    # autojunk is switched off so no character is ever treated as "popular"
    # noise, whatever the length of the text.
    matcher = difflib.SequenceMatcher(None, source, target, autojunk=False)
    segments = []
    for tag, src_start, src_end, tgt_start, tgt_end in matcher.get_opcodes():
        if tag == "equal":
            continue
        # The remaining opcode tags are exactly "replace", "delete", "insert".
        segments.append({
            "original": source[src_start:src_end],
            "corrected": target[tgt_start:tgt_end],
            "position": src_start,
            "type": tag,
        })
    return segments

def correct_single_sentence(sentence: str) -> str:
    """Run the correction model on one sentence and return the corrected text.

    The sentence is appended to a fixed instruction, rendered with the
    tokenizer's chat template (thinking disabled) and decoded greedily with at
    most 128 new tokens. The decoded text is passed through
    ``clean_model_output`` before being returned.
    """
    instruction = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:"
    chat = [{"role": "user", "content": instruction + sentence}]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    model_inputs = tokenizer([rendered], return_tensors="pt").to(model.device)
    sequences = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)

    # Drop the first len(prompt) ids of each returned sequence so that only
    # the newly generated part is decoded.
    continuations = []
    for prompt_ids, sequence in zip(model_inputs.input_ids, sequences):
        continuations.append(sequence[len(prompt_ids):])

    decoded = tokenizer.batch_decode(continuations, skip_special_tokens=True)
    return clean_model_output(decoded[0])

def _format_error_details(errors):
    """Render the collected corrections as the Markdown shown to the user."""
    if not errors:
        return "✅ 未发现错误,句子正确!"
    bullets = [
        f"- 位置 {error['position']}: `{error['original']}` → `{error['corrected']}`\n"
        for error in errors.values()
    ]
    return "**发现的错误:**\n\n" + "".join(bullets)


def text_correction(input_text):
    """Correct *input_text* segment by segment and report what changed.

    The text is split with ``split_paragraph_lst``, each segment is corrected
    with ``correct_single_sentence``, and the differences are collected with
    ``find_diff_segments``.

    The corrected segments are spliced back into the original text, so line
    breaks, ``'|'`` separators and whitespace-only stretches that the splitter
    skips are preserved in the corrected output instead of being dropped.
    Reported positions are indices into ``input_text`` whenever the segment
    can be located there; otherwise the splitter's offset is used.

    Args:
        input_text: Text entered by the user; may be ``None`` or empty.

    Returns:
        A ``(markdown, json_text)`` tuple. ``markdown`` shows the original
        text, the corrected text and the list of corrections; ``json_text``
        is a JSON document with the keys ``tgt`` (corrected text) and ``des``
        (corrections keyed ``error_1``, ``error_2``, ...). On empty input or
        on failure ``json_text`` is ``""`` and ``markdown`` carries a message.
    """
    logger.info("=" * 60)
    logger.info(f"[用户输入] {input_text}")

    # Guard against None as well as empty / whitespace-only input.
    if not input_text or not input_text.strip():
        return "请输入需要纠错的文本", ""

    try:
        start_time = datetime.now()

        # Split the paragraph into clauses.
        segments = split_paragraph_lst([input_text])
        logger.info(f"[分句结果] 共 {len(segments)} 个子句")

        # Segments are substrings of the blank-normalised text. Splicing is
        # only valid while that text lines up index-for-index with the input.
        searchable = replace_blanks(input_text)
        can_splice = len(searchable) == len(input_text)
        cursor = 0  # end of the last segment located in the original text

        all_errors = {}
        corrected_parts = []

        for offset, sent in segments:
            located = searchable.find(sent, cursor) if can_splice else -1
            if located == -1:
                # Segment not found verbatim (e.g. a '\r' was removed from its
                # middle): stop splicing and fall back to plain concatenation.
                can_splice = False
                position = offset
            else:
                # Keep whatever the splitter skipped between two segments.
                corrected_parts.append(input_text[cursor:located])
                cursor = located + len(sent)
                position = located

            logger.info(f"  [子句] offset={position} | {sent}")
            corrected = correct_single_sentence(sent)
            logger.info(f"  [纠正] {corrected}")
            corrected_parts.append(corrected)

            # Collect the differences, mapped back to positions in the input.
            for diff in find_diff_segments(sent, corrected):
                diff["position"] += position
                all_errors[f"error_{len(all_errors) + 1}"] = diff

        if can_splice:
            corrected_parts.append(input_text[cursor:])  # text after the last segment
        corrected_full = "".join(corrected_parts)

        duration = (datetime.now() - start_time).total_seconds()
        logger.info(f"[总耗时] {duration:.2f} 秒")

        result = {"tgt": corrected_full, "des": all_errors}
        result_json = json.dumps(result, ensure_ascii=False, indent=2)

        error_details = _format_error_details(all_errors)
        output_text = f"**原文:**\n{input_text}\n\n**纠正后:**\n{corrected_full}\n\n{error_details}"
        logger.info("[处理完成] ✓")

        return output_text, result_json

    except Exception as e:
        # UI boundary: log the traceback and show the message to the user
        # instead of letting the request fail.
        logger.error(f"[错误] {str(e)}", exc_info=True)
        return f"错误: {str(e)}", ""

# ==================== Gradio interface ====================
# Components are created inside nested Blocks/Row/Column context managers, so
# the order of the statements below defines the page layout.
with gr.Blocks(title="ChineseErrorCorrector3") as demo:
    gr.Markdown("# 🔍 ChineseErrorCorrector3")
    gr.Markdown("支持长段落输入,自动分句后逐句纠错(本地 CPU 推理,句子越多耗时越长)")

    # First row: input textbox and submit button in one column, the Markdown
    # result in the other.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="输入文本(支持长段落)",
                placeholder="例如:他每天都去跑部锻炼身体。对待每一项工作都要一丝不够。",
                lines=5
            )
            submit_btn = gr.Button("开始纠错", variant="primary")
        with gr.Column():
            output_display = gr.Markdown(label="纠错结果")

    # Second row: read-only textbox that receives the JSON result.
    with gr.Row():
        result_json = gr.Textbox(label="JSON 格式输出", lines=10, interactive=False)

    # Example inputs wired to the input textbox.
    gr.Examples(
        examples=[
            ["我的名字较做小明"],
            ["他每天都去跑部锻炼身体"]
        ],
        inputs=input_text
    )

    # Both the button's click event and the textbox's submit event call
    # text_correction and fill the same two outputs.
    submit_btn.click(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])
    input_text.submit(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])

# Launch the Gradio app only when the file is executed as a script.
if __name__ == "__main__":
    logger.info("启动中文文本纠错助手...")
    demo.launch()