lllouo committed on
Commit
0bd867c
·
1 Parent(s): 7a91a9a
Files changed (2) hide show
  1. app.py +153 -33
  2. requirements.txt +2 -1
app.py CHANGED
@@ -11,6 +11,17 @@ import spacy
11
  from spellchecker import SpellChecker
12
  import difflib
13
 
 
 
 
 
 
 
 
 
 
 
 
14
  # ======================== API配置 ========================
15
  DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
16
  DEEPSEEK_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
@@ -60,6 +71,56 @@ Next, please correct the following sentence according to the above requirements.
60
 
61
  [input]: """
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # ======================== 新增:颜色对比函数 ========================
64
  def generate_colored_diff(original, cleaned):
65
  """
@@ -163,12 +224,13 @@ def create_comparison_html(original_list, cleaned_list):
163
  return html
164
 
165
  # ======================== 工具函数 ========================
166
- def check_api_key():
167
- if not DEEPSEEK_API_KEY:
 
168
  raise ValueError("⚠️ 请在 Space Settings 中配置 DEEPSEEK_API_KEY!")
169
 
170
  def call_deepseek_api(prompt, model="deepseek-r1-distill-llama-8b", temperature=0.1, stream=True):
171
- check_api_key()
172
  client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)
173
  completion = client.chat.completions.create(
174
  model=model,
@@ -286,13 +348,19 @@ def search_leaderboard(df, query):
286
  return df
287
  return df[df['Benchmark'].str.contains(query, case=False, na=False)]
288
 
289
- # ======================== 数据清洗函数(修改版)========================
290
  def clean_dataset(file_path, question_column, model_choice, temperature, max_samples, progress=gr.Progress()):
291
  try:
 
292
  try:
293
- check_api_key()
294
  except ValueError as e:
295
- return str(e), None, ""
 
 
 
 
 
296
 
297
  progress(0.05, desc="📁 读取数据文件...")
298
  df = pd.read_parquet(file_path)
@@ -309,12 +377,18 @@ def clean_dataset(file_path, question_column, model_choice, temperature, max_sam
309
  war_original = calculate_whitespace_anomaly_rate(original_sentences)
310
  sed_original = calculate_spelling_error_density(original_sentences)
311
 
312
- progress(0.1, desc=f"🚀 开始清洗 {total} 个样本...")
 
 
 
 
 
 
313
 
314
- data_corrupt = [process_sentence(str(item)) for item in data_ori]
315
  results = []
316
- max_retries = 5
317
- log_text = f"🚀 开始处理 {total} 个样本...\n\n"
 
318
 
319
  for idx in range(total):
320
  progress((0.1 + 0.7 * idx / total), desc=f"处理中: {idx+1}/{total}")
@@ -326,21 +400,33 @@ def clean_dataset(file_path, question_column, model_choice, temperature, max_sam
326
 
327
  while retry_count < max_retries:
328
  try:
329
- response_content = call_deepseek_api(
330
- PROMPT_TEMPLATE + original_text,
331
- model=model_choice,
332
- temperature=float(temperature)
333
- )
 
 
 
 
334
 
335
- if is_valid_output(response_content, original_text, unprocess_text):
336
- results.append(response_content)
337
- break
 
 
 
 
338
  else:
339
- retry_count += 1
 
 
 
 
340
 
341
  except Exception as e:
342
  retry_count += 1
343
- log_text += f"⚠️ 样本 {idx+1} API错误,重试 {retry_count}/{max_retries}: {str(e)}\n"
344
  else:
345
  results.append(f"[ERROR] Failed to process: {original_text}")
346
  log_text += f"❌ 样本 {idx+1} 处理失败\n"
@@ -364,7 +450,7 @@ def clean_dataset(file_path, question_column, model_choice, temperature, max_sam
364
  lst_final = []
365
  for i in range(len(data_ori)):
366
  item = str(data_ori[i])
367
- if '\n' in item:
368
  tmp_lines = [line.strip() for line in item.strip().split('\n') if line.strip()]
369
  tmp_lines[-1] = lst_extracted[i]
370
  lst_final.append('\n'.join(tmp_lines))
@@ -386,7 +472,8 @@ def clean_dataset(file_path, question_column, model_choice, temperature, max_sam
386
 
387
  original_filename = os.path.basename(file_path)
388
  base_name = original_filename.replace('.parquet', '')
389
- output_filename = f"{base_name}-Denoising.parquet"
 
390
  output_path = os.path.join(tempfile.gettempdir(), output_filename)
391
 
392
  df_cleaned.to_parquet(output_path, index=False)
@@ -394,6 +481,7 @@ def clean_dataset(file_path, question_column, model_choice, temperature, max_sam
394
  log_text += f"\n\n📊 处理完成!\n"
395
  log_text += f"{'='*50}\n"
396
  log_text += f"【基础统计】\n"
 
397
  log_text += f"- 总样本数: {total}\n"
398
  log_text += f"- 成功处理: {total - error_count - unknown_count}\n"
399
  log_text += f"- 失败样本: {error_count}\n"
@@ -408,6 +496,10 @@ def clean_dataset(file_path, question_column, model_choice, temperature, max_sam
408
  log_text += f"📍 拼写错误密度(SED):\n"
409
  log_text += f" 原始: {sed_original:.2f}% → 清洗后: {sed_cleaned:.2f}%\n"
410
  log_text += f"   变化: {delta_sed:+.2f}% {'✅ 改善' if delta_sed < 0 else '⚠️ 增加'}\n"
 
 
 
 
411
  log_text += f"{'='*50}\n"
412
 
413
  # 生成带颜色的对比HTML
@@ -426,27 +518,39 @@ def clean_dataset(file_path, question_column, model_choice, temperature, max_sam
426
  ABOUT_TEXT = """
427
  ## 清洗流程说明
428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  ### 核心算法
430
 
431
  1. **预处理 (process_sentence)**
432
  - 检测句子完整性
433
- - 为不完整的句子添加标记 `___`
434
  - 保留多行文本格式
435
 
436
- 2. **LLM清洗**
437
- - 使用 DeepSeek API 进行语法、拼写、空格错误修正
438
- - 重试机制最多重试5次
439
- - 稳定的 REST API 调用
440
 
441
- 3. **格式验证 (is_valid_output)**
442
  - 验证输出格式正确性
443
- - 检查是否保留了 `___` 标记
444
  - 长度合理性检查
445
 
446
  4. **后处理**
447
  - 提取清洗后的内容
448
  - 恢复原始多行格式
449
- - 生成 `XXX-Denoising.parquet` 文件
450
 
451
  ### 支持的数据集
452
 
@@ -466,9 +570,11 @@ ABOUT_TEXT = """
466
  ### 技术栈
467
 
468
  - **LLM**: DeepSeek API (deepseek-r1-distill-llama-8b)
 
469
  - **前端**: Gradio 4.16.0
470
  - **数据处理**: Pandas + PyArrow (Parquet)
471
  - **差异对比**: Python difflib
 
472
  - **API调用**: OpenAI SDK
473
  - **部署**: Hugging Face Spaces
474
 
@@ -477,9 +583,16 @@ ABOUT_TEXT = """
477
  - **WAR (Whitespace Anomaly Rate)**: 空白符异常率
478
  - **SED (Spelling Error Density)**: 拼写错误密度
479
 
 
 
 
 
 
 
 
480
  ---
481
 
482
- **研究生毕业论文成果展示** | Powered by DeepSeek API
483
  """
484
 
485
  # ======================== Gradio界面 ========================
@@ -558,6 +671,11 @@ with demo:
558
  with gr.TabItem("🚀 BD-toolkit Demo", id=3):
559
  gr.Markdown("## BD-toolkit轻量化Demo展示")
560
 
 
 
 
 
 
561
  with gr.Row():
562
  with gr.Column():
563
  file_input = gr.File(
@@ -574,7 +692,8 @@ with demo:
574
  model_choice = gr.Dropdown(
575
  choices=["deepseek-r1-distill-llama-8b", "WAC-GEC"],
576
  value="deepseek-r1-distill-llama-8b",
577
- label="🤖 选择模型"
 
578
  )
579
 
580
  temperature = gr.Slider(
@@ -582,7 +701,8 @@ with demo:
582
  maximum=1.0,
583
  value=0.1,
584
  step=0.1,
585
- label="🌡️ Temperature"
 
586
  )
587
 
588
  max_samples = gr.Slider(
 
11
  from spellchecker import SpellChecker
12
  import difflib
13
 
14
+ # ======================== 新增:WAC-GEC导入 ========================
15
+ try:
16
+ from whitespace_correction import WhitespaceCorrector
17
+ WAC_GEC_AVAILABLE = True
18
+ # 初始化WAC-GEC模型(使用CPU,HF Space通常没有GPU)
19
+ wac_corrector = None # 延迟初始化
20
+ except ImportError:
21
+ WAC_GEC_AVAILABLE = False
22
+ wac_corrector = None
23
+ print("⚠️ whitespace_correction未安装,WAC-GEC功能将不可用")
24
+
25
  # ======================== API配置 ========================
26
  DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
27
  DEEPSEEK_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
 
71
 
72
  [input]: """
73
 
74
+ # ======================== 新增:WAC-GEC初始化函数 ========================
75
+ def initialize_wac_gec():
76
+ """延迟初始化WAC-GEC模型"""
77
+ global wac_corrector
78
+ if not WAC_GEC_AVAILABLE:
79
+ return False
80
+
81
+ if wac_corrector is None:
82
+ try:
83
+ # 根据环境选择设备
84
+ device = "cpu" # HF Space默认使用CPU
85
+ # 如果有GPU可用,取消下面的注释
86
+ # import torch
87
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
88
+
89
+ # 优先使用本地模型,如果不存在则自动下载
90
+ local_model_path = "./models" # HF Space中的模型目录
91
+ if os.path.exists(os.path.join(local_model_path, "eo_larger_byte")):
92
+ wac_corrector = WhitespaceCorrector.from_pretrained(
93
+ model="eo_larger_byte",
94
+ device=device,
95
+ download_dir=local_model_path
96
+ )
97
+ else:
98
+ # 如果本地没有,自动下载到默认缓存
99
+ wac_corrector = WhitespaceCorrector.from_pretrained(
100
+ model="eo_larger_byte",
101
+ device=device,
102
+ download_dir=None
103
+ )
104
+ print(f"✅ WAC-GEC模型已加载 (设备: {device})")
105
+ return True
106
+ except Exception as e:
107
+ print(f"❌ WAC-GEC模型加载失败: {e}")
108
+ return False
109
+ return True
110
+
111
+ # ======================== 新增:WAC-GEC处理函数 ========================
112
+ def call_wac_gec(text):
113
+ """使用WAC-GEC纠正空白符错误"""
114
+ if not initialize_wac_gec():
115
+ raise ValueError("⚠️ WAC-GEC模型未安装或加载失败")
116
+
117
+ try:
118
+ corrected = wac_corrector.correct_text(text)
119
+ # 格式化输出以匹配DeepSeek的格式
120
+ return f"[output]: {corrected}"
121
+ except Exception as e:
122
+ raise Exception(f"WAC-GEC处理错误: {str(e)}")
123
+
124
  # ======================== 新增:颜色对比函数 ========================
125
  def generate_colored_diff(original, cleaned):
126
  """
 
224
  return html
225
 
226
  # ======================== 工具函数 ========================
227
+ def check_api_key(model_choice):
228
+ """检查API密钥(仅DeepSeek需要)"""
229
+ if model_choice == "deepseek-r1-distill-llama-8b" and not DEEPSEEK_API_KEY:
230
  raise ValueError("⚠️ 请在 Space Settings 中配置 DEEPSEEK_API_KEY!")
231
 
232
  def call_deepseek_api(prompt, model="deepseek-r1-distill-llama-8b", temperature=0.1, stream=True):
233
+ check_api_key(model)
234
  client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)
235
  completion = client.chat.completions.create(
236
  model=model,
 
348
  return df
349
  return df[df['Benchmark'].str.contains(query, case=False, na=False)]
350
 
351
+ # ======================== 数据清洗函数(修改版:支持双模型)========================
352
  def clean_dataset(file_path, question_column, model_choice, temperature, max_samples, progress=gr.Progress()):
353
  try:
354
+ # 检查API密钥(仅DeepSeek需要)
355
  try:
356
+ check_api_key(model_choice)
357
  except ValueError as e:
358
+ if model_choice == "deepseek-r1-distill-llama-8b":
359
+ return str(e), None, ""
360
+
361
+ # 检查WAC-GEC可用性
362
+ if model_choice == "WAC-GEC" and not WAC_GEC_AVAILABLE:
363
+ return "❌ WAC-GEC模型未安装!请安装 whitespace_correction 包。", None, ""
364
 
365
  progress(0.05, desc="📁 读取数据文件...")
366
  df = pd.read_parquet(file_path)
 
377
  war_original = calculate_whitespace_anomaly_rate(original_sentences)
378
  sed_original = calculate_spelling_error_density(original_sentences)
379
 
380
+ progress(0.1, desc=f"🚀 开始清洗 {total} 个样本 (模型: {model_choice})...")
381
+
382
+ # WAC-GEC不需要添加___标记
383
+ if model_choice == "WAC-GEC":
384
+ data_corrupt = [str(item) for item in data_ori]
385
+ else:
386
+ data_corrupt = [process_sentence(str(item)) for item in data_ori]
387
 
 
388
  results = []
389
+ max_retries = 5 if model_choice == "deepseek-r1-distill-llama-8b" else 3
390
+ log_text = f"🚀 开始处理 {total} 个样本...\n"
391
+ log_text += f"📌 使用模型: {model_choice}\n\n"
392
 
393
  for idx in range(total):
394
  progress((0.1 + 0.7 * idx / total), desc=f"处理中: {idx+1}/{total}")
 
400
 
401
  while retry_count < max_retries:
402
  try:
403
+ # 根据模型选择调用不同的API
404
+ if model_choice == "WAC-GEC":
405
+ response_content = call_wac_gec(original_text)
406
+ else:
407
+ response_content = call_deepseek_api(
408
+ PROMPT_TEMPLATE + original_text,
409
+ model=model_choice,
410
+ temperature=float(temperature)
411
+ )
412
 
413
+ # WAC-GEC的输出格式简单,无需复杂验证
414
+ if model_choice == "WAC-GEC":
415
+ if response_content.startswith('[output]:'):
416
+ results.append(response_content)
417
+ break
418
+ else:
419
+ retry_count += 1
420
  else:
421
+ if is_valid_output(response_content, original_text, unprocess_text):
422
+ results.append(response_content)
423
+ break
424
+ else:
425
+ retry_count += 1
426
 
427
  except Exception as e:
428
  retry_count += 1
429
+ log_text += f"⚠️ 样本 {idx+1} 处理错误,重试 {retry_count}/{max_retries}: {str(e)}\n"
430
  else:
431
  results.append(f"[ERROR] Failed to process: {original_text}")
432
  log_text += f"❌ 样本 {idx+1} 处理失败\n"
 
450
  lst_final = []
451
  for i in range(len(data_ori)):
452
  item = str(data_ori[i])
453
+ if '\n' in item and model_choice != "WAC-GEC":
454
  tmp_lines = [line.strip() for line in item.strip().split('\n') if line.strip()]
455
  tmp_lines[-1] = lst_extracted[i]
456
  lst_final.append('\n'.join(tmp_lines))
 
472
 
473
  original_filename = os.path.basename(file_path)
474
  base_name = original_filename.replace('.parquet', '')
475
+ model_suffix = "WAC-GEC" if model_choice == "WAC-GEC" else "DeepSeek"
476
+ output_filename = f"{base_name}-Denoising-{model_suffix}.parquet"
477
  output_path = os.path.join(tempfile.gettempdir(), output_filename)
478
 
479
  df_cleaned.to_parquet(output_path, index=False)
 
481
  log_text += f"\n\n📊 处理完成!\n"
482
  log_text += f"{'='*50}\n"
483
  log_text += f"【基础统计】\n"
484
+ log_text += f"- 使用模型: {model_choice}\n"
485
  log_text += f"- 总样本数: {total}\n"
486
  log_text += f"- 成功处理: {total - error_count - unknown_count}\n"
487
  log_text += f"- 失败样本: {error_count}\n"
 
496
  log_text += f"📍 拼写错误密度(SED):\n"
497
  log_text += f" 原始: {sed_original:.2f}% → 清洗后: {sed_cleaned:.2f}%\n"
498
  log_text += f"   变化: {delta_sed:+.2f}% {'✅ 改善' if delta_sed < 0 else '⚠️ 增加'}\n"
499
+
500
+ if model_choice == "WAC-GEC":
501
+ log_text += f"\n💡 注意: WAC-GEC仅修正空白符错误,不修正拼写和语法错误\n"
502
+
503
  log_text += f"{'='*50}\n"
504
 
505
  # 生成带颜色的对比HTML
 
518
  ABOUT_TEXT = """
519
  ## 清洗流程说明
520
 
521
+ ### 支持的模型
522
+
523
+ #### 1. DeepSeek-R1 (deepseek-r1-distill-llama-8b)
524
+ - **功能**: 全面的语法、拼写、空格错误修正
525
+ - **优势**: 综合性强,能处理多种类型的错误
526
+ - **配置**: 需要在Space Settings中配置DEEPSEEK_API_KEY
527
+
528
+ #### 2. WAC-GEC (Whitespace Correction)
529
+ - **功能**: 专注于空白符错误纠正(多余空格、缺失空格等)
530
+ - **优势**: 轻量级,无需API密钥,处理速度快
531
+ - **限制**: 仅修正空白符错误,不处理拼写和语法问题
532
+ - **适用场景**: 数据集中主要存在空白符异常的情况
533
+
534
  ### 核心算法
535
 
536
  1. **预处理 (process_sentence)**
537
  - 检测句子完整性
538
+ - 为不完整的句子添加标记 `___` (仅DeepSeek)
539
  - 保留多行文本格式
540
 
541
+ 2. **模型清洗**
542
+ - **DeepSeek**: 使用API进行全面错误修正,重试机制最多5次
543
+ - **WAC-GEC**: 使用本地模型进行空白符纠正,重试机制最多3次
 
544
 
545
+ 3. **格式验证**
546
  - 验证输出格式正确性
547
+ - 检查标记保留情况
548
  - 长度合理性检查
549
 
550
  4. **后处理**
551
  - 提取清洗后的内容
552
  - 恢复原始多行格式
553
+ - 生成带模型标识的Parquet文件
554
 
555
  ### 支持的数据集
556
 
 
570
  ### 技术栈
571
 
572
  - **LLM**: DeepSeek API (deepseek-r1-distill-llama-8b)
573
+ - **本地模型**: WAC-GEC (Whitespace Correction)
574
  - **前端**: Gradio 4.16.0
575
  - **数据处理**: Pandas + PyArrow (Parquet)
576
  - **差异对比**: Python difflib
577
+ - **NLP工具**: spaCy, pyspellchecker
578
  - **API调用**: OpenAI SDK
579
  - **部署**: Hugging Face Spaces
580
 
 
583
  - **WAR (Whitespace Anomaly Rate)**: 空白符异常率
584
  - **SED (Spelling Error Density)**: 拼写错误密度
585
 
586
+ ### 模型选择建议
587
+
588
+ - **需要全面清洗**: 选择 DeepSeek-R1
589
+ - **仅需修正空格**: 选择 WAC-GEC(更快,无需API)
590
+ - **预算有限**: 优先使用 WAC-GEC
591
+ - **追求最佳效果**: 使用 DeepSeek-R1
592
+
593
  ---
594
 
595
+ **研究生毕业论文成果展示** | Powered by DeepSeek API & WAC-GEC
596
  """
597
 
598
  # ======================== Gradio界面 ========================
 
671
  with gr.TabItem("🚀 BD-toolkit Demo", id=3):
672
  gr.Markdown("## BD-toolkit轻量化Demo展示")
673
 
674
+ # 模型可用性提示
675
+ model_status = "✅ DeepSeek-R1: " + ("已配置" if DEEPSEEK_API_KEY else "未配置API密钥")
676
+ model_status += " | ✅ WAC-GEC: " + ("可用" if WAC_GEC_AVAILABLE else "未安装")
677
+ gr.Markdown(f"**模型状态**: {model_status}")
678
+
679
  with gr.Row():
680
  with gr.Column():
681
  file_input = gr.File(
 
692
  model_choice = gr.Dropdown(
693
  choices=["deepseek-r1-distill-llama-8b", "WAC-GEC"],
694
  value="deepseek-r1-distill-llama-8b",
695
+ label="🤖 选择模型",
696
+ info="DeepSeek: 全面纠错 | WAC-GEC: 仅空白符纠正"
697
  )
698
 
699
  temperature = gr.Slider(
 
701
  maximum=1.0,
702
  value=0.1,
703
  step=0.1,
704
+ label="🌡️ Temperature",
705
+ info="仅对DeepSeek生效"
706
  )
707
 
708
  max_samples = gr.Slider(
requirements.txt CHANGED
@@ -4,4 +4,5 @@ pandas
4
  pyarrow
5
  openai
6
  spacy
7
- pyspellchecker
 
 
4
  pyarrow
5
  openai
6
  spacy
7
+ pyspellchecker
8
+ whitespace-correction