lanny xu commited on
Commit
8821b53
·
1 Parent(s): 2d46508

add async

Browse files
Files changed (1) hide show
  1. hallucination_detector.py +84 -44
hallucination_detector.py CHANGED
@@ -175,13 +175,13 @@ class NLIHallucinationDetector:
175
  sentences = re.split(r'[。!?\.\!\?]\s*', text)
176
  return [s.strip() for s in sentences if s.strip()]
177
 
178
- def detect(self, generation: str, documents: str) -> Dict:
179
  """
180
- 检测幻觉
181
 
182
  Args:
183
  generation: LLM 生成的内容
184
- documents: 参考文档
185
 
186
  Returns:
187
  {
@@ -202,7 +202,19 @@ class NLIHallucinationDetector:
202
  "problematic_sentences": []
203
  }
204
 
205
- # 分割成句子
 
 
 
 
 
 
 
 
 
 
 
 
206
  sentences = self.split_sentences(generation)
207
 
208
  if not sentences:
@@ -220,60 +232,88 @@ class NLIHallucinationDetector:
220
  entailment_count = 0
221
  problematic_sentences = []
222
 
 
223
  for sentence in sentences:
224
  if len(sentence) < 10: # 跳过太短的句子
225
  continue
226
 
227
- try:
228
- # 根据模型类型调整输入格式
229
- if hasattr(self, 'model_name') and 'cross-encoder' in self.model_name:
230
- # Cross-encoder 模型:直接传入两个文本
231
- result = self.nli_model(
232
- f"{documents[:500]} [SEP] {sentence}",
233
- truncation=True,
234
- max_length=512
235
- )
236
- else:
237
- # 传统 NLI 模型使用 text 和 text_pair
238
- result = self.nli_model(
239
- sentence,
240
- documents[:500],
241
- truncation=True,
242
- max_length=512
243
- )
244
 
245
- # 处理结果
246
- if isinstance(result, list) and len(result) > 0:
247
- label = result[0]['label'].lower()
248
- else:
249
- print(f"⚠️ NLI 返回格式异常: {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  continue
251
-
252
- if 'contradiction' in label or 'contradict' in label:
253
- contradiction_count += 1
254
- problematic_sentences.append(sentence)
255
- elif 'neutral' in label:
256
- neutral_count += 1
257
- # neutral 只是中立,不一定是幻觉,不加入 problematic_sentences
258
- elif 'entailment' in label or 'entail' in label:
259
- entailment_count += 1
260
 
261
- except Exception as e:
262
- print(f"⚠️ NLI 检测句子失败: {str(e)[:100]}")
263
- import traceback
264
- print(f" 详细错误: {traceback.format_exc()[:200]}")
265
- continue
266
-
267
- # 判断是否有幻觉(只有明确矛盾才算幻觉)
268
- # neutral 表示文档中没有相关信息,但不一定是错误的
 
 
269
  total_sentences = contradiction_count + neutral_count + entailment_count
270
 
271
- # 只有当矛盾句子超过 30% 或者 neutral 超过 80% 才算幻觉
272
  has_hallucination = False
273
  if total_sentences > 0:
274
  contradiction_ratio = contradiction_count / total_sentences
275
  neutral_ratio = neutral_count / total_sentences
 
276
  has_hallucination = (contradiction_ratio > 0.3) or (neutral_ratio > 0.8)
 
 
 
277
 
278
  return {
279
  "has_hallucination": has_hallucination,
 
175
  sentences = re.split(r'[。!?\.\!\?]\s*', text)
176
  return [s.strip() for s in sentences if s.strip()]
177
 
178
+ def detect(self, generation: str, documents) -> Dict:
179
  """
180
+ 检测幻觉(支持多文档最大匹配策略)
181
 
182
  Args:
183
  generation: LLM 生成的内容
184
+ documents: 参考文档 (str 或 List[Document/str])
185
 
186
  Returns:
187
  {
 
202
  "problematic_sentences": []
203
  }
204
 
205
+ # 1. 预处理文档列表
206
+ docs_content = []
207
+ if isinstance(documents, list):
208
+ for doc in documents:
209
+ if hasattr(doc, 'page_content'):
210
+ docs_content.append(doc.page_content)
211
+ else:
212
+ docs_content.append(str(doc))
213
+ else:
214
+ # 如果是单个字符串,尝试按换行符分割,或者作为单文档处理
215
+ docs_content = [str(documents)]
216
+
217
+ # 2. 分割生成内容为句子
218
  sentences = self.split_sentences(generation)
219
 
220
  if not sentences:
 
232
  entailment_count = 0
233
  problematic_sentences = []
234
 
235
+ # 3. 逐句检测 (Max-Entailment Strategy)
236
  for sentence in sentences:
237
  if len(sentence) < 10: # 跳过太短的句子
238
  continue
239
 
240
+ # 默认为 Neutral (找不到支持)
241
+ best_label = "neutral"
242
+ best_score = 0.0
243
+
244
+ # 遍历所有文档块,寻找最佳匹配
245
+ # 只要有一个文档能 Entail (支持) 这个句子,就算通过
246
+ sentence_supported = False
247
+
248
+ for doc_content in docs_content:
249
+ # 截断单个文档块以适应模型 (保留前 800 字符,通常足够覆盖 512 tokens)
250
+ # 注意这里是对单个文档块截断,而不是对所有文档拼接后截断
251
+ premise = doc_content[:800]
 
 
 
 
 
252
 
253
+ try:
254
+ # NLI 推理
255
+ if hasattr(self, 'model_name') and 'cross-encoder' in self.model_name:
256
+ result = self.nli_model(
257
+ f"{premise} [SEP] {sentence}",
258
+ truncation=True,
259
+ max_length=512
260
+ )
261
+ else:
262
+ result = self.nli_model(
263
+ sentence,
264
+ premise,
265
+ truncation=True,
266
+ max_length=512
267
+ )
268
+
269
+ # 解析结果
270
+ if isinstance(result, list) and len(result) > 0:
271
+ current_label = result[0]['label'].lower()
272
+ current_score = result[0]['score']
273
+
274
+ # 优先级逻辑:Entailment > Contradiction > Neutral
275
+ # 如果找到 Entailment,立即停止查找(已验证)
276
+ if 'entailment' in current_label or 'entail' in current_label:
277
+ best_label = "entailment"
278
+ sentence_supported = True
279
+ break
280
+
281
+ # 如果是 Contradiction,记录下来,但继续找(也许其他文档能解释)
282
+ if 'contradiction' in current_label or 'contradict' in current_label:
283
+ # 只有当目前是 Neutral 时才更新为 Contradiction
284
+ # 这样防止 Contradiction 覆盖了潜在的 Entailment (虽然���面break了,但这逻辑保持严谨)
285
+ if best_label == "neutral":
286
+ best_label = "contradiction"
287
+ best_score = current_score
288
+
289
+ else:
290
+ continue
291
+
292
+ except Exception as e:
293
+ print(f"⚠️ NLI 子任务失败: {str(e)[:50]}")
294
  continue
 
 
 
 
 
 
 
 
 
295
 
296
+ # 统计该句子的最终判定
297
+ if best_label == "entailment":
298
+ entailment_count += 1
299
+ elif best_label == "contradiction":
300
+ contradiction_count += 1
301
+ problematic_sentences.append(sentence)
302
+ else: # neutral
303
+ neutral_count += 1
304
+
305
+ # 4. 综合评分
306
  total_sentences = contradiction_count + neutral_count + entailment_count
307
 
 
308
  has_hallucination = False
309
  if total_sentences > 0:
310
  contradiction_ratio = contradiction_count / total_sentences
311
  neutral_ratio = neutral_count / total_sentences
312
+ # 阈值判断
313
  has_hallucination = (contradiction_ratio > 0.3) or (neutral_ratio > 0.8)
314
+
315
+ # Debug 信息
316
+ print(f"📊 NLI 检测结果: Entail={entailment_count}, Contra={contradiction_count}, Neutral={neutral_count}")
317
 
318
  return {
319
  "has_hallucination": has_hallucination,