# -*- coding: utf-8 -*-
"""Gradio demo for Chinese text correction with ChineseErrorCorrector3-4B.

The input is cut into sentence-sized segments, each segment is corrected by
the model on CPU, and every difference is reported with its character
position in the original input.
"""
import difflib
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ==================== Logging configuration ====================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('text_correction.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# ==================== Model loading ====================
logger.info("正在加载模型,请稍候...")
model_name = "twnlp/ChineseErrorCorrector3-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # half the memory of float32 weights
    device_map="cpu",
    low_cpu_mem_usage=True,  # reduce peak memory while loading weights
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info("模型加载完成 ✓")

# ==================== Sentence splitting ====================
# NOTE: full-width punctuation (U+FFxx) and curly quotes below are written as
# \u escapes on purpose. Tools that fold full-width characters to ASCII would
# otherwise silently corrupt these tables (e.g. folding \uff0d to "-" turns
# "—\uff0d\-" into the invalid character range "—-\-").

# Characters normalised to a plain space before splitting. Every entry is a
# single character, so normalisation never changes the text length and
# segment offsets remain valid for the original input.
blanks = ["\ufeff", "\u3000", "\u2002", "\xa0", "\x07", "\x0b", "\x0c",
          "_", "\uff3f", "\u200d", "\u200c"]
_BLANK_TABLE = str.maketrans(dict.fromkeys(blanks, " "))

# Characters that never count as sentence-breaking punctuation: word
# characters, brackets, quotes, dashes, currency/unit symbols and whitespace.
_NON_BREAKING_CLASS = (
    "\\w"
    "《》\u201c\u201d【】\\[\\]<>\uff08\uff09()〔〕「」『』〖〗〈〉﹛﹜\uff5b\uff5d{}"
    "×—\uff0d\\-%\uff05\u00a5\uffe5$□℃\xa0\u3000\r\n \t"
)
# A run of two or more "breaking" punctuation characters.
_PUNCT_RUN_RE = re.compile("[^" + _NON_BREAKING_CLASS + "]{2,}")

# Closing quotes that may directly follow sentence-final punctuation.
_QUOTE_CLASS = "\"'\u201c\u201d\u2018\u2019"
# Sentence-final punctuation (。？！… and ASCII ?!|) NOT followed by a quote.
_SENT_END_RE = re.compile("[。\uff1f\uff01…?!|](?![" + _QUOTE_CLASS + "])")
# Sentence-final punctuation immediately followed by a closing quote.
_QUOTED_SENT_END_RE = re.compile(
    "(?:[。\uff1f\uff01!?|]|…{1,2})[" + _QUOTE_CLASS + "]"
)

# Clause separators (half- and full-width) at which over-long sentences may be cut.
sub_split_flag = [",", "\uff0c", ";", "\uff1b", ")", "\uff09"]
_COMMAS = (",", "\uff0c")

# Paragraph pieces: maximal runs of text between line breaks and "|".
_PIECE_RE = re.compile(r"[^\r\n|]+")


def replace_blanks(text):
    """Return *text* with every character listed in ``blanks`` replaced by a space."""
    return text.translate(_BLANK_TABLE)


def split_sentence(document_input: str, min_len: int = 16, max_len: int = 126):
    """Split text into sentence-like chunks paired with their start offsets.

    If the text contains a run of two or more punctuation characters, it is
    broken only after such runs; otherwise it is broken after sentence-final
    punctuation (optionally followed by a closing quote). Chunks longer than
    ``max_len`` are further cut at clause separators by ``split_subsentence``.
    If splitting fails unexpectedly, the whole input is returned as one chunk.

    Args:
        document_input: Text to split. Callers pass text without line breaks
            or ``|``; those characters would break the length invariant
            checked below.
        min_len: Minimum clause length forwarded to ``split_subsentence``.
        max_len: Chunks longer than this are split further.

    Returns:
        A list of ``[offset, chunk]`` pairs; the chunks concatenated equal
        ``document_input``.
    """
    sent_list = []
    try:
        if _PUNCT_RUN_RE.search(document_input):
            document = _PUNCT_RUN_RE.sub(r"\g<0>\n", document_input)
        else:
            document = _SENT_END_RE.sub(r"\g<0>\n", document_input)
            document = _QUOTED_SENT_END_RE.sub(r"\g<0>\n", document)
        for sent in document.split("\n"):
            sent = sent.replace("|", "")
            if not sent:
                continue
            if len(sent) > max_len:
                sent_list.extend(split_subsentence(sent, min_len=min_len))
            else:
                sent_list.append(sent)
    except Exception:
        # Best effort: an unsplit chunk is still correctable, so log and go on.
        logger.warning("split_sentence failed; keeping text unsplit", exc_info=True)
        sent_list = [document_input]
    # Internal invariant: splitting must neither drop nor add characters,
    # otherwise the offsets computed below would be wrong.
    assert sum(len(s) for s in sent_list) == len(document_input)

    res = []
    p = 0
    for sent in sent_list:
        res.append([p, sent])
        p += len(sent)
    return res


def split_subsentence(sentence, min_len=16):
    """Yield consecutive clauses of *sentence*, cut after clause separators.

    A cut is made after a separator in ``sub_split_flag`` only when the clause
    is at least ``min_len`` characters long, no comma follows within the next
    five characters, and the sentence end is more than five characters away
    (so no tiny fragments are produced). A separator at the second-to-last
    position absorbs the final character. The yielded clauses always
    concatenate back to *sentence*.
    """
    last = len(sentence) - 1
    chunk = ""
    for i, ch in enumerate(sentence):
        chunk += ch
        if ch not in sub_split_flag:
            if i == last:
                yield chunk
            continue
        if i == last - 1:
            # Never leave a one-character tail: glue the final character on.
            yield chunk + sentence[last]
            return
        end_is_near = i + 5 >= last
        comma_ahead = any(c in _COMMAS for c in sentence[i + 1:i + 6])
        if (len(chunk) >= min_len and not (end_is_near or comma_ahead)) or i == last:
            yield chunk
            chunk = ""


def split_paragraph_lst(paragraph_lst: List[str], min_len: int = 16, max_len: int = 126):
    """Split paragraphs into ``[offset, segment]`` pairs ready for correction.

    Each paragraph is blank-normalised with ``replace_blanks`` and cut at line
    breaks (``\\r``/``\\n``) and ``|``. Pieces longer than ``max_len`` are
    split further with ``split_sentence``, and whitespace-only pieces are
    skipped.

    Offsets index into the concatenation of ``paragraph_lst``, with the
    removed separators counted. ``text[offset:offset + len(segment)]``
    therefore lines up with the segment, up to blank normalisation, which
    preserves length.
    """
    res = []
    base = 0
    for paragraph in paragraph_lst:
        text = replace_blanks(paragraph)
        for piece_match in _PIECE_RE.finditer(text):
            piece = piece_match.group()
            if not piece.strip():
                continue
            start = base + piece_match.start()
            if len(piece) > max_len:
                for sub_offset, sub in split_sentence(piece, min_len=min_len, max_len=max_len):
                    if sub.strip():
                        res.append([start + sub_offset, sub])
            else:
                res.append([start, piece])
        base += len(paragraph)
    return res


# ==================== Correction core ====================
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)

# Instruction prefix sent before every sentence. The punctuation is
# full-width (，：) to match the usage example on the model card.
# NOTE(review): confirm against the twnlp/ChineseErrorCorrector3-4B README.
_CORRECTION_PROMPT = (
    "你是一个文本纠错专家\uff0c纠正输入句子中的语法错误\uff0c"
    "并输出正确的句子\uff0c输入句子为\uff1a"
)


def clean_model_output(text):
    """Remove any ``<think>...</think>`` blocks and surrounding whitespace."""
    return _THINK_RE.sub("", text).strip()


def find_diff_segments(source, target):
    """Return character-level edits that turn *source* into *target*.

    Each edit is a dict with ``original`` (source span), ``corrected``
    (target span), ``position`` (start index in *source*) and ``type``
    (``"replace"``, ``"delete"`` or ``"insert"``). Separate errors in one
    sentence are reported as separate edits.
    """
    if source == target:
        return []
    # autojunk=False: the "popular element" heuristic would otherwise ignore
    # frequent characters in inputs of 200+ characters.
    matcher = difflib.SequenceMatcher(None, source, target, autojunk=False)
    return [
        {
            "original": source[i1:i2],
            "corrected": target[j1:j2],
            "position": i1,
            "type": tag,
        }
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
        if tag != "equal"
    ]


def correct_single_sentence(sentence: str) -> str:
    """Run the model on one sentence and return the corrected text."""
    messages = [{"role": "user", "content": _CORRECTION_PROMPT + sentence}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    prompt_len = model_inputs.input_ids.shape[1]
    # A correction is about as long as its input. Allowing at least as many
    # new tokens as the prompt holds keeps long segments from being truncated,
    # which a fixed 128 could do.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max(128, prompt_len),
        do_sample=False,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    raw_output = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)
    return clean_model_output(raw_output)


def _correct_segments(input_text: str) -> Tuple[str, Dict[str, Dict[str, Any]]]:
    """Correct each segment of *input_text*; return (corrected text, errors).

    Corrections are spliced back into the original text. Line breaks,
    separators and whitespace between segments are therefore kept verbatim,
    and error positions refer to *input_text*.
    """
    segments = split_paragraph_lst([input_text])
    logger.info(f"[分句结果] 共 {len(segments)} 个子句")

    all_errors: Dict[str, Dict[str, Any]] = {}
    pieces = []
    cursor = 0
    for offset, sent in segments:
        # Send only the stripped core to the model. The surrounding whitespace
        # is copied from the input, so it is neither lost nor reported.
        core = sent.strip()
        offset += len(sent) - len(sent.lstrip())
        logger.info(f"  [子句] offset={offset} | {core}")
        corrected = correct_single_sentence(core)
        logger.info(f"  [纠正] {corrected}")

        pieces.append(input_text[cursor:offset])
        pieces.append(corrected)
        cursor = offset + len(core)

        for diff in find_diff_segments(core, corrected):
            diff["position"] += offset  # map back to the position in the input
            all_errors[f"error_{len(all_errors) + 1}"] = diff
    pieces.append(input_text[cursor:])
    return "".join(pieces), all_errors


def _format_markdown(input_text: str, corrected_full: str,
                     all_errors: Dict[str, Dict[str, Any]]) -> str:
    """Render the original text, corrected text and error list as Markdown."""
    if all_errors:
        error_details = "**发现的错误:**\n\n" + "".join(
            f"- 位置 {error['position']}: `{error['original']}` → `{error['corrected']}`\n"
            for error in all_errors.values()
        )
    else:
        error_details = "✅ 未发现错误,句子正确!"
    return f"**原文:**\n{input_text}\n\n**纠正后:**\n{corrected_full}\n\n{error_details}"


def text_correction(input_text: str) -> Tuple[str, str]:
    """Gradio handler: correct *input_text*; return (Markdown report, JSON text).

    The JSON has the form ``{"tgt": corrected_text, "des": {"error_N": edit}}``.
    Any failure is logged and reported as an error message with empty JSON.
    """
    logger.info("=" * 60)
    logger.info(f"[用户输入] {input_text}")

    if not input_text or not input_text.strip():
        return "请输入需要纠错的文本", ""

    try:
        start_time = datetime.now()
        corrected_full, all_errors = _correct_segments(input_text)
        duration = (datetime.now() - start_time).total_seconds()
        logger.info(f"[总耗时] {duration:.2f} 秒")

        result = {"tgt": corrected_full, "des": all_errors}
        json_text = json.dumps(result, ensure_ascii=False, indent=2)
        output_text = _format_markdown(input_text, corrected_full, all_errors)

        logger.info("[处理完成] ✓")
        return output_text, json_text
    except Exception as e:
        # UI boundary: show the error to the user instead of crashing the app.
        logger.exception(f"[错误] {str(e)}")
        return f"错误: {str(e)}", ""


# ==================== Gradio UI ====================
with gr.Blocks(title="ChineseErrorCorrector3") as demo:
    gr.Markdown("# 🔍 ChineseErrorCorrector3")
    gr.Markdown("支持长段落输入,自动分句后逐句纠错(本地 CPU 推理,句子越多耗时越长)")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="输入文本(支持长段落)",
                placeholder="例如:他每天都去跑部锻炼身体。对待每一项工作都要一丝不够。",
                lines=5
            )
            submit_btn = gr.Button("开始纠错", variant="primary")
        with gr.Column():
            output_display = gr.Markdown(label="纠错结果")

    with gr.Row():
        result_json = gr.Textbox(label="JSON 格式输出", lines=10, interactive=False)

    gr.Examples(
        examples=[
            ["我的名字较做小明"],
            ["他每天都去跑部锻炼身体"]
        ],
        inputs=input_text
    )

    submit_btn.click(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])
    input_text.submit(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])


if __name__ == "__main__":
    logger.info("启动中文文本纠错助手...")
    demo.launch()