Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| import gradio as gr | |
| import re | |
| import logging | |
| from datetime import datetime | |
| import json | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import List | |
# ==================== Logging setup ====================
# Log both to a file and to the console with a single shared format.
_log_handlers = [
    logging.FileHandler('text_correction.log', encoding='utf-8'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
# ==================== Model loading ====================
# Loads the correction model once at import time; inference runs on CPU.
logger.info("正在加载模型,请稍候...")
model_name = "twnlp/ChineseErrorCorrector3-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # halves memory vs float32; supported by modern CPUs
    device_map="cpu",
    low_cpu_mem_usage=True,  # reduces peak memory while the weights are loaded
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info("模型加载完成 ✓")
| # ==================== 段落分割 ==================== | |
# Characters that are normalized to a plain ASCII space before splitting.
# NOTE(review): "_" appears twice — presumably one entry was meant to be the
# full-width underscore "_" (U+FF3F); confirm against the original source.
blanks = ["\ufeff", "\u3000", "\u2002", "\xa0", "\x07", "\x0b", "\x0c", "_", "_", "\u200d", "\u200c"]

# Pre-built translation table: str.translate does all substitutions in one
# C-level pass instead of one str.replace traversal per character in `blanks`.
_BLANKS_TABLE = str.maketrans({ch: " " for ch in blanks})


def replace_blanks(text):
    """Return *text* with every character in ``blanks`` replaced by a space.

    The replacement is strictly 1:1 per character, so the result has the same
    length as the input — downstream offset bookkeeping relies on this.
    """
    return text.translate(_BLANKS_TABLE)
| def split_sentence(document_input: str, min_len: int = 16, max_len: int = 126): | |
| sent_list = [] | |
| try: | |
| punctuation_flag = re.search( | |
| r"""[^\w《》""【】\[\]<>()()〔〕「」『』〖〗〈〉﹛﹜{}×—-\-%%¥$□℃\xa0\u3000\r\n \t]{2,}""", | |
| document_input | |
| ) | |
| if punctuation_flag: | |
| document = re.sub( | |
| r"""(?P<quotation_mark>([^\w《》""【】\[\]<>()()〔〕「」『』〖〗〈〉﹛﹜{}×—-\-%%¥$□℃\xa0\u3000\r\n \t]{2,}))""", | |
| r'\g<quotation_mark>\n', document_input | |
| ) | |
| else: | |
| document = re.sub( | |
| r"""(?P<quotation_mark>([。?!…?!|](?!["'"\'])))""", | |
| r'\g<quotation_mark>\n', document_input | |
| ) | |
| document = re.sub( | |
| r"""(?P<quotation_mark>(([。?!!?|]|…{1,2})["'"\']))""", | |
| r'\g<quotation_mark>\n', document | |
| ) | |
| sent_list_ori = document.split('\n') | |
| for sent in sent_list_ori: | |
| sent = sent.replace('|', '') | |
| if not sent: | |
| continue | |
| if len(sent) > max_len: | |
| sent_list.extend(split_subsentence(sent, min_len=min_len)) | |
| else: | |
| sent_list.append(sent) | |
| except: | |
| sent_list.clear() | |
| sent_list.append(document_input) | |
| assert sum(len(s) for s in sent_list) == len(document_input) | |
| p = 0 | |
| res = [] | |
| for sent in sent_list: | |
| res.append([p, sent]) | |
| p += len(sent) | |
| return res | |
| sub_split_flag = [',', ',', ';', ';', ')', ')'] | |
| def split_subsentence(sentence, min_len=16): | |
| sent = '' | |
| for i, c in enumerate(sentence): | |
| sent += c | |
| if c in sub_split_flag: | |
| if i == len(sentence) - 2: | |
| yield sent[:-1] + c + sentence[-1] | |
| break | |
| flag = True | |
| for j in range(i + 1, min(len(sentence) - 1, i + 6)): | |
| if sentence[j] == ',' or j == len(sentence) - 1: | |
| flag = False | |
| if (flag and len(sent) >= min_len) or i == len(sentence) - 1: | |
| yield sent[:-1] + c | |
| sent = '' | |
| elif i == len(sentence) - 1: | |
| yield sent | |
def split_paragraph_lst(paragraph_lst: List[str], min_len: int = 16, max_len: int = 126):
    """Split paragraphs into model-sized clauses with accurate source offsets.

    Returns a list of ``[offset, clause]`` pairs, where ``offset`` is the
    clause's character position in the concatenation of *paragraph_lst*.

    Bug fix: the previous implementation dropped the '\\r', '\\n' and '|'
    separator characters *before* computing offsets, so every clause after
    the first separator was reported at a position smaller than its real one
    (these offsets are mapped back onto the original text by the caller).
    Splitting with ``re.split`` on the single-character separators and
    advancing the cursor by ``len(fragment) + 1`` keeps offsets aligned.
    Minor behavior note: text around a bare '\\r' now yields two fragments
    instead of being glued into one; empty fragments are skipped either way.
    """
    fragments = []  # [offset, fragment] pairs
    base = 0
    for paragraph in paragraph_lst:
        # replace_blanks substitutes characters 1:1, so lengths are preserved.
        cleaned = replace_blanks(paragraph)
        cursor = 0
        for frag in re.split(r'[\r\n|]', cleaned):
            fragments.append([base + cursor, frag])
            cursor += len(frag) + 1  # +1 for the one-char separator consumed by the split
        base += len(cleaned)
    res = []
    for offset_sent, sent in fragments:
        if not sent.strip():
            continue  # skip empty / whitespace-only fragments
        if len(sent) > max_len:
            # Too long for one model call: split further and rebase sub-offsets.
            for offset_subsent, subsent in split_sentence(sent, min_len=min_len, max_len=max_len):
                if not subsent.strip():
                    continue
                res.append([offset_sent + offset_subsent, subsent])
        else:
            res.append([offset_sent, sent])
    return res
# ==================== Correction core ====================
def clean_model_output(text):
    """Strip any ``<think>...</think>`` reasoning span (even across newlines)
    and surrounding whitespace from the model's raw output."""
    without_reasoning = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    return without_reasoning.strip()
def find_diff_segments(source, target):
    """Locate the single differing span between *source* and *target*.

    Trims the longest common prefix and (non-overlapping) suffix, then
    returns a one-element list describing the remaining difference:
    ``{"original", "corrected", "position", "type"}`` where *type* is
    "replace", "delete" or "insert".  Returns ``[]`` when the strings match.
    """
    if source == target:
        return []
    src_len, tgt_len = len(source), len(target)
    # Longest common prefix.
    head = 0
    shorter = min(src_len, tgt_len)
    while head < shorter and source[head] == target[head]:
        head += 1
    # Longest common suffix that does not reach into the prefix.
    tail = 0
    while tail < shorter - head and source[src_len - 1 - tail] == target[tgt_len - 1 - tail]:
        tail += 1
    removed = source[head:src_len - tail]
    added = target[head:tgt_len - tail]
    if not removed and not added:
        return []
    if removed and added:
        change_kind = "replace"
    elif removed:
        change_kind = "delete"
    else:
        change_kind = "insert"
    return [{
        "original": removed,
        "corrected": added,
        "position": head,
        "type": change_kind,
    }]
def correct_single_sentence(sentence: str) -> str:
    """Run the correction model on one sentence and return the corrected text.

    Builds a single-turn chat prompt, generates greedily (do_sample=False),
    and strips the echoed prompt plus any <think> span from the output.
    """
    instruction = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:"
    chat = [{"role": "user", "content": instruction + sentence}]
    rendered = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    inputs = tokenizer([rendered], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    # Keep only the newly generated tokens (drop the echoed prompt prefix).
    completions = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(inputs.input_ids, outputs)
    ]
    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
    return clean_model_output(decoded)
def text_correction(input_text):
    """Correct a (possibly long) passage end to end.

    Splits the input into clauses, runs the model on each one, collects the
    per-clause differences (with positions mapped back to the original
    text), and returns ``(markdown_report, json_string)`` for the UI.
    """
    logger.info("=" * 60)
    logger.info(f"[用户输入] {input_text}")
    if not input_text.strip():
        return "请输入需要纠错的文本", ""
    try:
        started = datetime.now()
        # Split the paragraph into model-sized clauses.
        clauses = split_paragraph_lst([input_text])
        logger.info(f"[分句结果] 共 {len(clauses)} 个子句")
        errors = {}
        corrected_pieces = []
        error_count = 0
        for offset, clause in clauses:
            logger.info(f" [子句] offset={offset} | {clause}")
            fixed = correct_single_sentence(clause)
            logger.info(f" [纠正] {fixed}")
            corrected_pieces.append(fixed)
            # Collect the clause-local diff and rebase it onto the original text.
            for diff in find_diff_segments(clause, fixed):
                error_count += 1
                diff["position"] = offset + diff["position"]
                errors[f"error_{error_count}"] = diff
        corrected_full = "".join(corrected_pieces)
        elapsed = (datetime.now() - started).total_seconds()
        logger.info(f"[总耗时] {elapsed:.2f} 秒")
        result_json = json.dumps({"tgt": corrected_full, "des": errors},
                                 ensure_ascii=False, indent=2)
        if errors:
            error_details = "**发现的错误:**\n\n"
            for err in errors.values():
                error_details += f"- 位置 {err['position']}: `{err['original']}` → `{err['corrected']}`\n"
        else:
            error_details = "✅ 未发现错误,句子正确!"
        output_text = f"**原文:**\n{input_text}\n\n**纠正后:**\n{corrected_full}\n\n{error_details}"
        logger.info("[处理完成] ✓")
        return output_text, result_json
    except Exception as e:
        logger.error(f"[错误] {str(e)}", exc_info=True)
        return f"错误: {str(e)}", ""
# ==================== Gradio UI ====================
# Two-column layout: input + trigger on the left, rendered report on the
# right, raw JSON below. Both the button and Enter run text_correction.
with gr.Blocks(title="ChineseErrorCorrector3") as demo:
    gr.Markdown("# 🔍 ChineseErrorCorrector3")
    gr.Markdown("支持长段落输入,自动分句后逐句纠错(本地 CPU 推理,句子越多耗时越长)")
    with gr.Row():
        with gr.Column():
            # Left column: user input plus the trigger button.
            input_text = gr.Textbox(
                label="输入文本(支持长段落)",
                placeholder="例如:他每天都去跑部锻炼身体。对待每一项工作都要一丝不够。",
                lines=5
            )
            submit_btn = gr.Button("开始纠错", variant="primary")
        with gr.Column():
            # Right column: human-readable Markdown report.
            output_display = gr.Markdown(label="纠错结果")
    with gr.Row():
        # Machine-readable result: {"tgt": corrected text, "des": errors}.
        result_json = gr.Textbox(label="JSON 格式输出", lines=10, interactive=False)
    gr.Examples(
        examples=[
            ["我的名字较做小明"],
            ["他每天都去跑部锻炼身体"]
        ],
        inputs=input_text
    )
    # Same handler for button click and Enter-submit in the textbox.
    submit_btn.click(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])
    input_text.submit(fn=text_correction, inputs=input_text, outputs=[output_display, result_json])

if __name__ == "__main__":
    logger.info("启动中文文本纠错助手...")
    demo.launch()