# twnlp's picture
# Upload 2 files
# 4526d38 verified
# -*- coding: utf-8 -*-
import gradio as gr
import re
import logging
from datetime import datetime
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List
# ==================== Logging configuration ====================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Log to a UTF-8 file and to the console simultaneously.
        logging.FileHandler('text_correction.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# ==================== Model loading ====================
# NOTE: loading happens at import time; the app cannot start without the model.
logger.info("正在加载模型,请稍候...")
model_name = "twnlp/ChineseErrorCorrector3-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # halves memory footprint; supported on modern CPUs
    device_map="cpu",
    low_cpu_mem_usage=True,  # reduce peak memory usage while loading
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info("模型加载完成 ✓")
# ==================== 段落分割 ====================
# Whitespace-like / invisible characters that must be normalised to a plain space.
blanks = ["\ufeff", "\u3000", "\u2002", "\xa0", "\x07", "\x0b", "\x0c", "_", "_", "\u200d", "\u200c"]

def replace_blanks(text):
    """Return *text* with every character listed in ``blanks`` replaced by a space."""
    table = str.maketrans({ch: " " for ch in blanks})
    return text.translate(table)
def split_sentence(document_input: str, min_len: int = 16, max_len: int = 126):
    """Split *document_input* into sentences; return ``[[offset, sentence], ...]``.

    Sentences longer than *max_len* are further broken on minor punctuation
    via :func:`split_subsentence`.  Offsets index into *document_input*, so the
    concatenation of the returned sentences must reproduce the input exactly.
    """
    sent_list = []
    try:
        # NOTE(fix): the original character class contained ``—-\-`` — a range
        # from em dash (U+2014) down to '-' (U+002D), which is an invalid
        # (descending) range.  re.search therefore raised re.error at pattern
        # compile time and the bare ``except`` silently disabled all splitting.
        # The hyphen is now escaped (``—\-``), keeping the intended character set.
        punctuation_flag = re.search(
            r"""[^\w《》""【】\[\]<>()()〔〕「」『』〖〗〈〉﹛﹜{}×—\-%%¥$□℃\xa0\u3000\r\n \t]{2,}""",
            document_input
        )
        if punctuation_flag:
            # Runs of 2+ punctuation characters mark sentence boundaries.
            document = re.sub(
                r"""(?P<quotation_mark>([^\w《》""【】\[\]<>()()〔〕「」『』〖〗〈〉﹛﹜{}×—\-%%¥$□℃\xa0\u3000\r\n \t]{2,}))""",
                r'\g<quotation_mark>\n', document_input
            )
        else:
            # Break after terminal punctuation NOT followed by a closing quote...
            document = re.sub(
                r"""(?P<quotation_mark>([。?!…?!|](?!["'"\'])))""",
                r'\g<quotation_mark>\n', document_input
            )
            # ...and after terminal punctuation that IS followed by a closing quote.
            document = re.sub(
                r"""(?P<quotation_mark>(([。?!!?|]|…{1,2})["'"\']))""",
                r'\g<quotation_mark>\n', document
            )
        sent_list_ori = document.split('\n')
        for sent in sent_list_ori:
            sent = sent.replace('|', '')
            if not sent:
                continue
            if len(sent) > max_len:
                sent_list.extend(split_subsentence(sent, min_len=min_len))
            else:
                sent_list.append(sent)
    except Exception:
        # Best-effort fallback: treat the whole input as a single sentence.
        sent_list.clear()
        sent_list.append(document_input)
    # Splitting must be lossless for the offsets to be valid.  '|' removal
    # above can drop characters; fall back to the unsplit input in that case.
    # (Was an ``assert``, which is stripped under ``python -O``.)
    if sum(len(s) for s in sent_list) != len(document_input):
        sent_list = [document_input]
    p = 0
    res = []
    for sent in sent_list:
        res.append([p, sent])
        p += len(sent)
    return res
# Minor punctuation on which an over-long sentence may be sub-split.
sub_split_flag = [',', ',', ';', ';', ')', ')']

def split_subsentence(sentence, min_len=16):
    """Yield chunks of *sentence*, splitting after minor punctuation.

    A split is only made when the accumulated chunk is at least *min_len*
    characters and no fullwidth comma (or the sentence end) follows within
    the next 5 characters.  The yielded chunks concatenate back to *sentence*.
    """
    sent = ''
    for i, c in enumerate(sentence):
        sent += c
        if c in sub_split_flag:
            if i == len(sentence) - 2:
                # Punctuation immediately before the final character:
                # emit everything (including the last char) and stop.
                yield sent[:-1] + c + sentence[-1]
                break
            flag = True
            for j in range(i + 1, min(len(sentence) - 1, i + 6)):
                # Suppress the split if another fullwidth comma (or the end
                # of the sentence) is close by — prefer the later boundary.
                if sentence[j] == ',' or j == len(sentence) - 1:
                    flag = False
            if (flag and len(sent) >= min_len) or i == len(sentence) - 1:
                yield sent[:-1] + c  # equals `sent`; keeps the delimiter attached
                sent = ''
        elif i == len(sentence) - 1:
            # Trailing remainder that does not end in split punctuation.
            yield sent
def split_paragraph_lst(paragraph_lst: List[str], min_len: int = 16, max_len: int = 126):
    """Normalise paragraphs and split them into offset-tagged sentences.

    Each paragraph is whitespace-normalised, then broken on '\\r', '\\n' and
    '|'.  Pieces longer than *max_len* are further split via
    :func:`split_sentence`.  Returns ``[[offset, sentence], ...]`` where the
    offsets index into the concatenation of the normalised pieces.
    """
    # Pre-processing: normalise blanks, drop '\r', break on newlines and '|'.
    pieces = []
    for paragraph in paragraph_lst:
        cleaned = replace_blanks(paragraph).replace('\r', '')
        for line in cleaned.split('\n'):
            pieces.extend(line.split('|'))
    # Base offset of each piece within the normalised text.
    offsets = []
    cursor = 0
    for piece in pieces:
        offsets.append(cursor)
        cursor += len(piece)
    result = []
    for base, piece in zip(offsets, pieces):
        piece = piece.replace('|', '')
        if not piece.strip():
            continue
        if len(piece) <= max_len:
            result.append([base, piece])
            continue
        # Over-long piece: split into sub-sentences, re-basing their offsets.
        for sub_offset, sub in split_sentence(piece, min_len=min_len, max_len=max_len):
            if sub.strip():
                result.append([base + sub_offset, sub])
    return result
# ==================== 纠错核心 ====================
def clean_model_output(text):
    """Drop chain-of-thought markup emitted by the model.

    Removes every non-greedy ``<think>...</think>`` span (matching across
    newlines) and trims leading/trailing whitespace from what remains.
    """
    thought_pattern = re.compile(r'<think>.*?</think>', re.DOTALL)
    return thought_pattern.sub('', text).strip()
def find_diff_segments(source, target):
    """Locate the single differing span between *source* and *target*.

    Strips the longest common prefix and (non-overlapping) suffix, then
    reports what remains.  Returns ``[]`` when the strings are equal,
    otherwise a one-element list with keys ``original``, ``corrected``,
    ``position`` (index in *source*) and ``type`` (replace/delete/insert).
    """
    if source == target:
        return []
    src_len, tgt_len = len(source), len(target)
    # Longest common prefix length.
    lcp = 0
    limit = min(src_len, tgt_len)
    while lcp < limit and source[lcp] == target[lcp]:
        lcp += 1
    # Longest common suffix length, not allowed to overlap the prefix.
    lcs = 0
    while (lcs < min(src_len - lcp, tgt_len - lcp)
           and source[src_len - 1 - lcs] == target[tgt_len - 1 - lcs]):
        lcs += 1
    original = source[lcp:src_len - lcs] if src_len - lcs > lcp else ""
    corrected = target[lcp:tgt_len - lcs] if tgt_len - lcs > lcp else ""
    if not original and not corrected:
        return []
    if original and corrected:
        kind = "replace"
    elif original:
        kind = "delete"
    else:
        kind = "insert"
    return [{
        "original": original,
        "corrected": corrected,
        "position": lcp,
        "type": kind,
    }]
def correct_single_sentence(sentence: str) -> str:
    """Run the correction model on one sentence and return the corrected text.

    Builds a chat-template prompt, generates greedily (no sampling, up to 128
    new tokens), drops the echoed prompt tokens, decodes, and strips any
    <think> block from the result.
    """
    instruction = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:"
    chat = [{"role": "user", "content": instruction + sentence}]
    rendered = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    inputs = tokenizer([rendered], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    # Keep only the newly generated tokens (strip the echoed prompt).
    completions = []
    for prompt_ids, full_ids in zip(inputs.input_ids, outputs):
        completions.append(full_ids[len(prompt_ids):])
    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
    return clean_model_output(decoded)
def text_correction(input_text):
    """Gradio handler: correct *input_text* sentence by sentence.

    Returns a ``(markdown_summary, json_string)`` pair.  The JSON carries the
    fully corrected text under ``"tgt"`` and per-error descriptors under
    ``"des"``; error positions are re-based onto the normalised input text.
    """
    logger.info("=" * 60)
    logger.info(f"[用户输入] {input_text}")
    if not input_text.strip():
        return "请输入需要纠错的文本", ""
    try:
        start_time = datetime.now()
        # Split the paragraph into sub-sentences for per-sentence inference.
        segments = split_paragraph_lst([input_text])
        logger.info(f"[分句结果] 共 {len(segments)} 个子句")
        all_errors = {}
        corrected_parts = []
        error_count = 0
        for offset, sent in segments:
            logger.info(f" [子句] offset={offset} | {sent}")
            corrected = correct_single_sentence(sent)
            logger.info(f" [纠正] {corrected}")
            corrected_parts.append(corrected)
            # Collect the differences between original and corrected sentence.
            diffs = find_diff_segments(sent, corrected)
            for diff in diffs:
                error_count += 1
                diff["position"] = offset + diff["position"]  # map back to position in the full input
                all_errors[f"error_{error_count}"] = diff
        corrected_full = "".join(corrected_parts)
        duration = (datetime.now() - start_time).total_seconds()
        logger.info(f"[总耗时] {duration:.2f} 秒")
        result = {"tgt": corrected_full, "des": all_errors}
        result_json = json.dumps(result, ensure_ascii=False, indent=2)
        if all_errors:
            # Render each error as "position: original → corrected".
            error_details = "**发现的错误:**\n\n"
            for key, error in all_errors.items():
                error_details += f"- 位置 {error['position']}: `{error['original']}` → `{error['corrected']}`\n"
        else:
            error_details = "✅ 未发现错误,句子正确!"
        output_text = f"**原文:**\n{input_text}\n\n**纠正后:**\n{corrected_full}\n\n{error_details}"
        logger.info("[处理完成] ✓")
        return output_text, result_json
    except Exception as e:
        # Top-level boundary: log with traceback and surface the message to the UI.
        logger.error(f"[错误] {str(e)}", exc_info=True)
        return f"错误: {str(e)}", ""
# ==================== Gradio 界面 ====================
# Build the Gradio UI: input column, output column, JSON view and examples.
with gr.Blocks(title="ChineseErrorCorrector3") as demo:
    gr.Markdown("# 🔍 ChineseErrorCorrector3")
    gr.Markdown("支持长段落输入,自动分句后逐句纠错(本地 CPU 推理,句子越多耗时越长)")
    with gr.Row():
        with gr.Column():
            # Left column: free-text input and submit button.
            input_text = gr.Textbox(
                label="输入文本(支持长段落)",
                placeholder="例如:他每天都去跑部锻炼身体。对待每一项工作都要一丝不够。",
                lines=5
            )
            submit_btn = gr.Button("开始纠错", variant="primary")
        with gr.Column():
            # Right column: rendered markdown summary of the correction.
            output_display = gr.Markdown(label="纠错结果")
    with gr.Row():
        # Raw machine-readable result.
        result_json = gr.Textbox(label="JSON 格式输出", lines=10, interactive=False)
    # One-click example inputs.
    gr.Examples(
        examples=[
            ["我的名字较做小明"],
            ["他每天都去跑部锻炼身体"]
        ],
        inputs=input_text
    )
    # Both the button click and pressing Enter in the textbox run the correction.
    submit_btn.click(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])
    input_text.submit(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])

if __name__ == "__main__":
    logger.info("启动中文文本纠错助手...")
    demo.launch()