kofdai committed on
Commit
ef23ea1
·
verified ·
1 Parent(s): b7a3aa6

Upload judge_beta_lobe_advanced.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. judge_beta_lobe_advanced.py +483 -0
judge_beta_lobe_advanced.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import re
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
# Per-domain regex patterns used to detect dangerous claims.
# Each entry is a (regex, claim_type, severity) triple.
_MEDICAL_DANGEROUS_PATTERNS = [
    (r"(必ず|絶対に|確実に).*(治る|完治|治癒)", "absolute_cure_claim", "critical"),
    (r"(副作用|リスク).*(ない|ありません|存在しない)", "no_side_effects_claim", "critical"),
    (r"(すべての|全ての|あらゆる)患者に(有効|効果的)", "universal_effectiveness", "high"),
    (r"(西洋|現代)医学.*(不要|いらない|無意味)", "anti_medicine_claim", "critical"),
    (r"(自己判断|自分で).*治療", "self_treatment_encouragement", "moderate"),
    (r"医師.*(相談|受診).*(不要|いらない|必要ない)", "avoid_doctor_claim", "critical"),
]

_LEGAL_DANGEROUS_PATTERNS = [
    (r"(必ず|絶対に|確実に).*(勝訴|勝てる|認められる)", "absolute_outcome_claim", "critical"),
    (r"弁護士.*(不要|いらない|必要ない)", "avoid_lawyer_claim", "critical"),
    (r"(すべての|全ての)ケースで", "universal_applicability", "high"),
    (r"(違法|犯罪).*(ではない|にならない).*絶対", "absolute_legality_claim", "critical"),
    (r"(時効|期限).*(気にしなくて|無視して)", "ignore_deadlines", "critical"),
    (r"(判例|法律).*無視", "ignore_precedent", "high"),
]

_ECONOMICS_DANGEROUS_PATTERNS = [
    (r"(必ず|絶対に|確実に).*(儲かる|利益|リターン)", "guaranteed_profit_claim", "critical"),
    (r"リスク.*(ない|ゼロ|存在しない)", "no_risk_claim", "critical"),
    (r"(すべての|全ての)投資家に", "universal_advice", "high"),
    (r"(買う|売る)べき.*絶対", "absolute_trading_advice", "critical"),
    (r"市場.*予測.*確実", "certain_market_prediction", "high"),
    (r"(暴落|暴騰).*(ない|しない).*絶対", "absolute_market_stability", "high"),
]

DOMAIN_DANGEROUS_PATTERNS = {
    "medical": _MEDICAL_DANGEROUS_PATTERNS,
    "legal": _LEGAL_DANGEROUS_PATTERNS,
    "economics": _ECONOMICS_DANGEROUS_PATTERNS,
}

# Kept for backward compatibility: alias of the medical pattern list.
DANGEROUS_CLAIM_PATTERNS = DOMAIN_DANGEROUS_PATTERNS["medical"]

# Plausibility ranges for medical measurements, keyed by the Japanese name
# of the measurement.
MEDICAL_VALUE_RANGES = {
    "血圧": {"systolic": (60, 250), "diastolic": (40, 150)},
    "体温": {"min": 35.0, "max": 42.0},
    "心拍数": {"min": 30, "max": 220},
    "SpO2": {"min": 70, "max": 100},
    "血糖値": {"min": 20, "max": 600},
}

# Validation patterns for the legal domain.
LEGAL_VALIDATION_PATTERNS = {
    "disclaimer_required": r"(免責|情報提供|法的助言ではありません)",
    "statute_citation": r"(第\d+条|条文|法律)",
    "precedent_citation": r"(判例|最判|最決|高判)",
}

# Validation patterns for the economics domain.
ECONOMICS_VALIDATION_PATTERNS = {
    "data_source_required": r"(統計|データ|出典|IMF|日銀|内閣府)",
    "uncertainty_disclosure": r"(予測|推計|不確実|シナリオ)",
    "disclaimer_required": r"(投資助言ではありません|自己責任)",
}
61
+
62
+
63
class BetaLobeAdvanced:
    """Advanced checks for the verification lobe (β-Lobe).

    Validates a generated response from several angles: consistency with
    anchor facts taken from DB knowledge tiles, simple logical-fallacy
    detection, domain-specific context checks (medical / legal /
    economics), dangerous-claim detection, plausibility of medical values,
    and an aggregated hallucination-risk score.
    """

    # Caps used when extracting anchor facts and building recommendations.
    _MAX_FACTS_PER_TILE = 5
    _MAX_FACT_LENGTH = 200
    _MAX_RECOMMENDATIONS = 3

    # Number extraction. A leading sign is only attached to decimal numbers,
    # so hyphenated integer ranges such as "10-20" are not read as negatives.
    _NUMBER_RE = re.compile(r'[-+]?\d*\.\d+|\d+')
    # Blood pressure "SYS/DIA". The lookarounds stop the pattern from
    # matching inside longer digit runs or dates (e.g. "2023/12/25").
    _BLOOD_PRESSURE_RE = re.compile(r'(?<![\d/])(\d{2,3})/(\d{2,3})(?![\d/])')
    # Two-digit temperature with optional decimals followed by a unit. The
    # lookbehind prevents matching the tail of "100度" or of "38.55℃".
    _TEMPERATURE_RE = re.compile(r'(?<![\d.])(\d{2}(?:\.\d+)?)\s*(?:°C|度|℃)')

    def __init__(self, db_interface, medical_ontology):
        """Store the collaborators used by the checks.

        Args:
            db_interface: Object providing an awaitable
                ``search_treatment(name)`` method.
            medical_ontology: Stored as ``self.ontology``; not used by any
                method in this class.
        """
        self.db = db_interface
        self.ontology = medical_ontology

    # --- Basic anchor-fact checks ---
    def _is_mentioned(self, fact: str, response: str) -> bool:
        """Return True when more than half of the fact's keywords occur in the response.

        Keywords are the whitespace-separated tokens of ``fact`` longer than
        one character. Text without whitespace (typical for Japanese) is
        therefore treated as a single keyword.
        """
        fact_keywords = [word for word in fact.split() if len(word) > 1]
        if not fact_keywords:
            return False
        mentioned_count = sum(1 for kw in fact_keywords if kw in response)
        return (mentioned_count / len(fact_keywords)) > 0.5

    def _detect_numerical_contradiction(self, fact: str, response: str) -> bool:
        """Detect a numerical contradiction between a fact and the response.

        Only the first number in ``fact`` is compared. The response
        contradicts it when it contains no number at all, or when every
        number in it deviates from the fact value by more than 10%.

        Returns:
            False when the fact has no number or some response number is
            within 10% of it; True otherwise.
        """
        fact_numbers = self._NUMBER_RE.findall(fact)
        if not fact_numbers:
            return False
        fact_value = float(fact_numbers[0])

        response_numbers = self._NUMBER_RE.findall(response)
        if not response_numbers:
            return True

        # Use the magnitude as the scale so negative fact values are compared
        # relatively too; the 0.001 floor avoids division by zero.
        scale = max(abs(fact_value), 0.001)
        return all(
            abs(float(res_val) - fact_value) / scale > 0.1
            for res_val in response_numbers
        )

    async def check_anchor_facts(self, response_text: str, db_context: dict) -> dict:
        """Check the response against the DB knowledge tiles.

        Args:
            response_text: The response to verify.
            db_context: Mapping of tile coordinate to tile; falsy tiles are
                skipped and ``None`` is treated as an empty mapping.

        Returns:
            Dict with ``contradictions`` (list), ``contradiction_count`` and
            ``passed``.
        """
        contradictions = []

        for coord, tile in (db_context or {}).items():
            if not tile:
                continue

            for fact in self._extract_anchor_facts(tile):
                statement = fact["statement"]
                # Only facts the response actually mentions are compared.
                if not self._is_mentioned(statement, response_text):
                    continue
                if fact.get("has_numbers") and self._detect_numerical_contradiction(statement, response_text):
                    contradictions.append({
                        "type": "numerical_contradiction",
                        "fact": statement,
                        "source": coord,
                        "severity": "high"
                    })

        return {
            "contradictions": contradictions,
            "contradiction_count": len(contradictions),
            "passed": len(contradictions) == 0
        }

    def _extract_anchor_facts(self, tile: dict) -> list:
        """Extract up to five fact statements from a tile for verification.

        Reads ``tile["content"]`` (falling back to ``tile["data"]``). Lines
        longer than 10 characters containing one of the markers "は",
        "である", "です" or ":" become facts, truncated to 200 characters.
        Non-dict tiles and non-string content yield an empty list.
        """
        facts = []
        if not isinstance(tile, dict):
            return facts

        content = tile.get("content", "") or tile.get("data", "")
        if not isinstance(content, str):
            return facts

        markers = ("は", "である", "です", ":")
        for raw_line in content.split("\n"):
            line = raw_line.strip()
            if len(line) <= 10 or not any(marker in line for marker in markers):
                continue
            facts.append({
                "statement": line[:self._MAX_FACT_LENGTH],
                "has_numbers": bool(re.search(r'\d+', line))
            })
            if len(facts) >= self._MAX_FACTS_PER_TILE:
                break

        return facts

    # --- Dangerous-claim detection (domain aware) ---
    def _detect_dangerous_claims(self, response: str, domain: str = "medical") -> list:
        """Detect dangerous claims using the patterns of the given domain.

        Unknown domains fall back to the medical patterns. At most one issue
        is reported per pattern (its first match).
        """
        issues = []
        patterns = DOMAIN_DANGEROUS_PATTERNS.get(domain, DOMAIN_DANGEROUS_PATTERNS["medical"])

        for pattern, claim_type, severity in patterns:
            match = re.search(pattern, response)
            if match:
                issues.append({
                    "type": "dangerous_claim",
                    "domain": domain,
                    "claim_type": claim_type,
                    "matched_text": match.group(0),
                    "severity": severity,
                    "message": f"危険な主張を検出 [{domain}]: {claim_type}"
                })
        return issues

    # --- Domain-specific validation ---
    def _validate_legal_response(self, response: str) -> list:
        """Legal-domain checks: a disclaimer and a statute reference must be present."""
        issues = []

        # Disclaimer check.
        if not re.search(LEGAL_VALIDATION_PATTERNS["disclaimer_required"], response):
            issues.append({
                "type": "missing_disclaimer",
                "domain": "legal",
                "severity": "high",
                "message": "法的免責事項が欠落しています"
            })

        # Statute citation check; deliberately kept at warning level.
        if not re.search(LEGAL_VALIDATION_PATTERNS["statute_citation"], response):
            issues.append({
                "type": "missing_citation",
                "domain": "legal",
                "severity": "moderate",
                "message": "条文への参照がありません"
            })

        return issues

    def _validate_economics_response(self, response: str) -> list:
        """Economics-domain checks: data source reference and uncertainty disclosure."""
        issues = []

        # Data source check.
        if not re.search(ECONOMICS_VALIDATION_PATTERNS["data_source_required"], response):
            issues.append({
                "type": "missing_data_source",
                "domain": "economics",
                "severity": "moderate",
                "message": "データ出典への参照がありません"
            })

        # Forecast-like responses must disclose uncertainty.
        # NOTE(review): the uncertainty pattern itself contains "予測", so this
        # check can only fire for responses that say "見通し" without any of
        # the pattern's terms — confirm whether that is intended.
        if "予測" in response or "見通し" in response:
            if not re.search(ECONOMICS_VALIDATION_PATTERNS["uncertainty_disclosure"], response):
                issues.append({
                    "type": "missing_uncertainty_disclosure",
                    "domain": "economics",
                    "severity": "high",
                    "message": "予測の不確実性が明示されていません"
                })

        return issues

    # --- Medical value plausibility ---
    def _validate_medical_values(self, response: str) -> list:
        """Check that blood pressure and body temperature values are plausible.

        Blood pressure is read from ``SYS/DIA`` pairs and temperature from
        two-digit values followed by °C / 度 / ℃. Values outside
        ``MEDICAL_VALUE_RANGES`` produce an ``invalid_medical_value`` issue.
        """
        issues = []

        # Blood pressure.
        bp_ranges = MEDICAL_VALUE_RANGES["血圧"]
        for systolic, diastolic in self._BLOOD_PRESSURE_RE.findall(response):
            s, d = int(systolic), int(diastolic)
            if not (bp_ranges["systolic"][0] <= s <= bp_ranges["systolic"][1]):
                issues.append({
                    "type": "invalid_medical_value",
                    "value_type": "血圧(収縮期)",
                    "value": s,
                    "expected_range": bp_ranges["systolic"],
                    "severity": "high"
                })
            if not (bp_ranges["diastolic"][0] <= d <= bp_ranges["diastolic"][1]):
                issues.append({
                    "type": "invalid_medical_value",
                    "value_type": "血圧(拡張期)",
                    "value": d,
                    "expected_range": bp_ranges["diastolic"],
                    "severity": "high"
                })

        # Body temperature.
        temp_ranges = MEDICAL_VALUE_RANGES["体温"]
        for temp in self._TEMPERATURE_RE.findall(response):
            t = float(temp)
            if not (temp_ranges["min"] <= t <= temp_ranges["max"]):
                issues.append({
                    "type": "invalid_medical_value",
                    "value_type": "体温",
                    "value": t,
                    "expected_range": (temp_ranges["min"], temp_ranges["max"]),
                    "severity": "high"
                })

        return issues

    # --- Advanced checks ---

    def _detect_false_dichotomy(self, response: str) -> list:
        """Detect a false dichotomy in the response.

        NOTE(review): the simplified pattern only matches the literal
        letters "A" and "B" (e.g. "AかBしかない"), not arbitrary terms.
        """
        errors = []
        dichotomy_pattern = r"(AかBのいずれかしかない|AかBしかない)"  # simplified pattern
        # ASCII spaces are removed before matching.
        if re.search(dichotomy_pattern, response.replace(" ", "")):
            errors.append({"type": "false_dichotomy", "statement": response, "severity": "moderate"})
        return errors

    async def _check_logical_consistency(self, question, alpha_response) -> dict:
        """Check the logical validity of the reasoning in the response.

        Only false-dichotomy detection is implemented. Circular reasoning,
        logical leaps and unfounded assumptions need heavier NLP and are
        treated as passing placeholders.
        """
        # Missing text is treated as empty, consistent with validate_response.
        response_text = alpha_response.get("main_response", "")

        errors = list(self._detect_false_dichotomy(response_text))

        return {"logical_errors": errors, "error_count": len(errors), "passed": len(errors) == 0}

    async def _verify_treatment_validity(self, response_text, db_context) -> dict:
        """Verify treatments mentioned in the response against the DB.

        Treatment names are taken from phrases of the form "Xが良い",
        "Xが有効な治療法" and "Xを投与"; each name is looked up through
        ``self.db.search_treatment``. ``db_context`` is accepted but unused.
        """
        issues = []

        # Simplified treatment extraction; findall returns one tuple per match
        # with a single non-empty group, so flatten and drop the empty groups.
        raw_matches = re.findall(r"(\S+)が良い|(\w+)が有効な治療法|(\w+)を投与", response_text)
        extracted_phrases = [item for tpl in raw_matches for item in tpl if item]

        # Drop a leading topic ("…には" / "…は") so only the treatment name remains.
        processed_treatments = []
        for phrase in extracted_phrases:
            if "には" in phrase:
                processed_treatments.append(phrase.split("には")[-1])
            elif "は" in phrase:
                processed_treatments.append(phrase.split("は")[-1])
            else:
                processed_treatments.append(phrase)

        for treatment in processed_treatments:
            if not treatment:
                continue
            treatment_info = await self.db.search_treatment(treatment)
            if not treatment_info:
                issues.append({"type": "unknown_treatment", "treatment": treatment, "severity": "moderate", "message": f"「{treatment}」は未知の治療法"})
            elif not treatment_info.get("is_validated"):
                issues.append({"type": "unvalidated_treatment", "treatment": treatment, "severity": "critical", "message": f"「{treatment}」は未検証の治療法"})

        return {"valid": len(issues) == 0, "issues": issues}

    async def _check_medical_context(self, response_text: str, db_context: dict) -> dict:
        """Check that the medical context of the response is appropriate.

        Only treatment validity is checked; diagnostic criteria, values and
        contraindications are placeholders.
        """
        issues = []
        treatment_check = await self._verify_treatment_validity(response_text, db_context)
        if not treatment_check["valid"]:
            issues.extend(treatment_check["issues"])

        return {"issues": issues, "issue_count": len(issues), "passed": len(issues) == 0}

    async def validate_response(self, question: str, alpha_response: dict, db_context: dict, web_results=None, session_context=None, domain: str = "medical") -> dict:
        """Validate a response from multiple angles (basic + advanced, domain aware).

        Args:
            question: The original question.
            alpha_response: Dict read for ``main_response``, ``domain`` and
                ``confidence``.
            db_context: Mapping of tile coordinate to knowledge tile.
            web_results: Accepted but unused.
            session_context: Accepted but unused.
            domain: Fallback domain when ``alpha_response`` carries none.

        Returns:
            Dict with the per-check results under ``checks``, the flattened
            ``all_issues``, the overall ``severity``, the
            ``hallucination_risk`` and ``recommendations``.
        """
        response_text = alpha_response.get("main_response", "")
        # Domain information on the response takes precedence when present.
        domain = alpha_response.get("domain") or domain

        logger.info("BetaLobe検証開始: domain=%s", domain)

        # 1. Basic anchor-fact check.
        anchor_check = await self.check_anchor_facts(response_text, db_context)

        # 2. Advanced logic check.
        logic_check = await self._check_logical_consistency(question, alpha_response)

        # 3. Domain-specific context check.
        if domain == "medical":
            context_check = await self._check_medical_context(response_text, db_context)
        elif domain == "legal":
            context_issues = self._validate_legal_response(response_text)
            context_check = {"issues": context_issues, "issue_count": len(context_issues), "passed": len(context_issues) == 0}
        elif domain == "economics":
            context_issues = self._validate_economics_response(response_text)
            context_check = {"issues": context_issues, "issue_count": len(context_issues), "passed": len(context_issues) == 0}
        else:
            context_check = {"issues": [], "issue_count": 0, "passed": True}

        # 4. Domain-specific dangerous-claim detection.
        dangerous_claims = self._detect_dangerous_claims(response_text, domain)
        safety_check = {
            "issues": dangerous_claims,
            "issue_count": len(dangerous_claims),
            "passed": len(dangerous_claims) == 0
        }

        # 5. Value plausibility (medical only; other domains may be added later).
        if domain == "medical":
            value_issues = self._validate_medical_values(response_text)
        else:
            value_issues = []
        value_check = {
            "issues": value_issues,
            "issue_count": len(value_issues),
            "passed": len(value_issues) == 0
        }

        # Aggregate every issue.
        all_issues = (
            anchor_check["contradictions"] +
            logic_check["logical_errors"] +
            context_check["issues"] +
            safety_check["issues"] +
            value_check["issues"]
        )

        # Overall severity is the highest severity present.
        severity = "none"
        if any(i.get("severity") == "critical" for i in all_issues):
            severity = "critical"
        elif any(i.get("severity") == "high" for i in all_issues):
            severity = "high"
        elif any(i.get("severity") == "moderate" for i in all_issues):
            severity = "moderate"

        # Hallucination-risk score.
        hallucination_risk = self._calculate_hallucination_risk(
            alpha_response, anchor_check, logic_check, context_check, safety_check
        )

        validation_result = {
            "timestamp": datetime.now().isoformat(),
            "response_text": response_text[:500],  # long responses are truncated
            "checks": {
                "anchor_facts": anchor_check,
                "logic": logic_check,
                "context": context_check,
                "safety": safety_check,
                "medical_values": value_check
            },
            "all_issues": all_issues,
            "issue_count": len(all_issues),
            "has_contradictions": len(all_issues) > 0,
            "severity": severity,
            "hallucination_risk": hallucination_risk,
            "recommendations": self._generate_recommendations(all_issues)
        }

        logger.info(
            "検証完了: %d件の問題, 重大度=%s, ハルシネーションリスク=%.2f",
            len(all_issues), severity, hallucination_risk["score"],
        )
        return validation_result

    def _calculate_hallucination_risk(self, alpha_response, anchor_check, logic_check, context_check, safety_check) -> dict:
        """Compute the hallucination-risk score from the check results.

        Failed checks add fixed weights (anchor 0.4, logic 0.2, context
        0.15, safety 0.25); a confidence below 0.4 adds 0.1. The total is
        capped at 1.0. A missing or ``None`` confidence counts as 0.5.

        Returns:
            Dict with ``score``, ``level`` and ``action_required`` (True
            when the score is at least 0.25).
        """
        score = 0.0

        # Contradiction with anchor facts.
        if not anchor_check["passed"]:
            score += 0.4

        # Logic errors.
        if not logic_check["passed"]:
            score += 0.2

        # Domain context problems.
        if not context_check["passed"]:
            score += 0.15

        # Dangerous claims.
        if not safety_check["passed"]:
            score += 0.25

        # Penalty for low confidence.
        confidence = alpha_response.get("confidence", 0.5)
        if confidence is None:
            confidence = 0.5
        if confidence < 0.4:
            score += 0.1

        final_score = min(1.0, score)

        # Risk level classification.
        if final_score < 0.1:
            level = "very_low"
        elif final_score < 0.25:
            level = "low"
        elif final_score < 0.5:
            level = "moderate"
        elif final_score < 0.75:
            level = "high"
        else:
            level = "critical"

        return {
            "score": final_score,
            "level": level,
            "action_required": final_score >= 0.25
        }

    def _generate_recommendations(self, all_issues: list) -> list:
        """Generate up to three fix recommendations from the issues.

        Issues are scanned in order and those with a known type produce a
        recommendation. The cap applies to the recommendations produced, so
        issues of other types do not use up the limit.
        """
        recommendations = []

        for issue in all_issues:
            if len(recommendations) >= self._MAX_RECOMMENDATIONS:
                break

            issue_type = issue.get("type", "unknown")

            if issue_type == "dangerous_claim":
                recommendations.append({
                    "type": "remove_dangerous_claim",
                    "message": f"危険な主張を削除または修正: {issue.get('claim_type')}",
                    "priority": "high"
                })
            elif issue_type == "numerical_contradiction":
                recommendations.append({
                    "type": "verify_numbers",
                    "message": f"数値を確認: {issue.get('fact', '')[:50]}",
                    "priority": "medium"
                })
            elif issue_type == "invalid_medical_value":
                recommendations.append({
                    "type": "correct_value",
                    "message": f"{issue.get('value_type')}の値が範囲外: {issue.get('value')}",
                    "priority": "high"
                })

        return recommendations