Spaces:
Sleeping
Sleeping
Upload bayesian_core.py
Browse files- bayesian_core.py +120 -5
bayesian_core.py
CHANGED
|
@@ -349,14 +349,129 @@ class BayesianHierarchicalAnalyzer:
|
|
| 349 |
'heterogeneity': heterogeneity
|
| 350 |
}
|
| 351 |
|
| 352 |
-
def get_model_graph(self):
|
| 353 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
if self.model is None:
|
| 355 |
raise ValueError("請先執行分析")
|
| 356 |
|
| 357 |
try:
|
| 358 |
-
|
| 359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
except Exception as e:
|
| 361 |
raise Exception(f"無法生成 DAG 圖: {str(e)}")
|
| 362 |
|
|
@@ -369,4 +484,4 @@ class BayesianHierarchicalAnalyzer:
|
|
| 369 |
def clear_session_results(cls, session_id):
|
| 370 |
"""清除特定 session 的結果"""
|
| 371 |
if session_id in cls._session_results:
|
| 372 |
-
del cls._session_results[session_id]
|
|
|
|
| 349 |
'heterogeneity': heterogeneity
|
| 350 |
}
|
| 351 |
|
| 352 |
+
def get_model_graph(self, language='zh'):
|
| 353 |
+
"""
|
| 354 |
+
生成模型 DAG 圖(返回 graphviz 物件)
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
language: 'zh' 中文 | 'en' 英文
|
| 358 |
+
|
| 359 |
+
Returns:
|
| 360 |
+
graphviz.Digraph 物件
|
| 361 |
+
"""
|
| 362 |
if self.model is None:
|
| 363 |
raise ValueError("請先執行分析")
|
| 364 |
|
| 365 |
try:
|
| 366 |
+
import graphviz
|
| 367 |
+
|
| 368 |
+
# 獲取欄位名稱
|
| 369 |
+
control_prefix = getattr(self, 'col_control_win', 'control').replace('_win', '').replace('_battles', '').replace('_total', '')
|
| 370 |
+
treatment_prefix = getattr(self, 'col_treatment_win', 'treatment').replace('_win', '').replace('_battles', '').replace('_total', '')
|
| 371 |
+
control_win = getattr(self, 'col_control_win', 'control_win')
|
| 372 |
+
treatment_win = getattr(self, 'col_treatment_win', 'treatment_win')
|
| 373 |
+
|
| 374 |
+
# 定義中英文對照
|
| 375 |
+
labels_zh = {
|
| 376 |
+
# 先驗參數
|
| 377 |
+
'd': f'整體效應 (d)\n{treatment_prefix} vs {control_prefix}',
|
| 378 |
+
'tau': '精度 (tau)',
|
| 379 |
+
'sigma': '配對間標準差 (σ)',
|
| 380 |
+
|
| 381 |
+
# 階層參數
|
| 382 |
+
'mu': f'基線對數勝算 (μ[i])\n各道館基準',
|
| 383 |
+
'delta': f'配對特定效應 (δ[i])\n各道館{treatment_prefix}優勢',
|
| 384 |
+
'delta_new': f'新配對預測 (δ_new)',
|
| 385 |
+
|
| 386 |
+
# 轉換參數
|
| 387 |
+
f'p_{control_prefix}': f'{control_prefix.capitalize()}勝率 (p_{control_prefix}[i])',
|
| 388 |
+
f'p_{treatment_prefix}': f'{treatment_prefix.capitalize()}勝率 (p_{treatment_prefix}[i])',
|
| 389 |
+
|
| 390 |
+
# 觀測值
|
| 391 |
+
f'{control_win}_obs': f'{control_prefix.capitalize()}勝場觀測 ({control_win}_obs[i])',
|
| 392 |
+
f'{treatment_win}_obs': f'{treatment_prefix.capitalize()}勝場觀測 ({treatment_win}_obs[i])',
|
| 393 |
+
|
| 394 |
+
# 其他
|
| 395 |
+
'or_speed': f'勝算比 (OR)\nexp(d)'
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
labels_en = {
|
| 399 |
+
# Priors
|
| 400 |
+
'd': f'Overall Effect (d)\n{treatment_prefix} vs {control_prefix}',
|
| 401 |
+
'tau': 'Precision (tau)',
|
| 402 |
+
'sigma': 'Between-Pair SD (σ)',
|
| 403 |
+
|
| 404 |
+
# Hierarchy
|
| 405 |
+
'mu': f'Baseline Log-Odds (μ[i])\nGym Baselines',
|
| 406 |
+
'delta': f'Pair-Specific Effect (δ[i])\n{treatment_prefix.capitalize()} Advantage per Gym',
|
| 407 |
+
'delta_new': f'New Pair Prediction (δ_new)',
|
| 408 |
+
|
| 409 |
+
# Transformations
|
| 410 |
+
f'p_{control_prefix}': f'{control_prefix.capitalize()} Win Rate (p_{control_prefix}[i])',
|
| 411 |
+
f'p_{treatment_prefix}': f'{treatment_prefix.capitalize()} Win Rate (p_{treatment_prefix}[i])',
|
| 412 |
+
|
| 413 |
+
# Observations
|
| 414 |
+
f'{control_win}_obs': f'{control_prefix.capitalize()} Wins Obs ({control_win}_obs[i])',
|
| 415 |
+
f'{treatment_win}_obs': f'{treatment_prefix.capitalize()} Wins Obs ({treatment_win}_obs[i])',
|
| 416 |
+
|
| 417 |
+
# Others
|
| 418 |
+
'or_speed': f'Odds Ratio (OR)\nexp(d)'
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
# 選擇語言
|
| 422 |
+
labels = labels_zh if language == 'zh' else labels_en
|
| 423 |
+
|
| 424 |
+
# 使用 PyMC 生成基本圖
|
| 425 |
+
gv_original = pm.model_to_graphviz(self.model)
|
| 426 |
+
|
| 427 |
+
# 創建新圖,使用自定義標籤
|
| 428 |
+
dot = graphviz.Digraph(comment='Bayesian Hierarchical Model')
|
| 429 |
+
dot.attr(rankdir='TB') # 從上到下
|
| 430 |
+
dot.attr('node', shape='ellipse', style='filled', fontname='Arial')
|
| 431 |
+
|
| 432 |
+
# 解析原始圖的結構
|
| 433 |
+
lines = gv_original.source.split('\n')
|
| 434 |
+
|
| 435 |
+
# 重建圖,替換標籤
|
| 436 |
+
for line in lines:
|
| 437 |
+
line = line.strip()
|
| 438 |
+
|
| 439 |
+
if not line or line.startswith('digraph') or line == '}':
|
| 440 |
+
continue
|
| 441 |
+
|
| 442 |
+
# 處理節點定義
|
| 443 |
+
if '[label=' in line:
|
| 444 |
+
# 提取節點名稱
|
| 445 |
+
node_name = line.split('[')[0].strip()
|
| 446 |
+
|
| 447 |
+
# 使用自定義標籤
|
| 448 |
+
custom_label = labels.get(node_name, node_name)
|
| 449 |
+
|
| 450 |
+
# 設定節點樣式
|
| 451 |
+
if 'obs' in node_name:
|
| 452 |
+
# 觀測值 - 方框,淺紅色
|
| 453 |
+
dot.node(node_name, custom_label, shape='box', fillcolor='lightcoral')
|
| 454 |
+
elif node_name in ['sigma', f'p_{control_prefix}', f'p_{treatment_prefix}']:
|
| 455 |
+
# 確定性節點 - 菱形,灰色
|
| 456 |
+
dot.node(node_name, custom_label, shape='diamond', fillcolor='lightgray')
|
| 457 |
+
elif node_name in ['mu', 'delta']:
|
| 458 |
+
# 階層參數 - 橢圓,綠色
|
| 459 |
+
dot.node(node_name, custom_label, fillcolor='lightgreen')
|
| 460 |
+
else:
|
| 461 |
+
# 先驗參數 - 橢圓,黃色
|
| 462 |
+
dot.node(node_name, custom_label, fillcolor='lightyellow')
|
| 463 |
+
|
| 464 |
+
# 處理邊
|
| 465 |
+
elif '->' in line:
|
| 466 |
+
# 提取邊的定義
|
| 467 |
+
edge_parts = line.replace(';', '').split('->')
|
| 468 |
+
if len(edge_parts) == 2:
|
| 469 |
+
from_node = edge_parts[0].strip()
|
| 470 |
+
to_node = edge_parts[1].strip().split('[')[0].strip()
|
| 471 |
+
dot.edge(from_node, to_node)
|
| 472 |
+
|
| 473 |
+
return dot
|
| 474 |
+
|
| 475 |
except Exception as e:
|
| 476 |
raise Exception(f"無法生成 DAG 圖: {str(e)}")
|
| 477 |
|
|
|
|
| 484 |
def clear_session_results(cls, session_id):
|
| 485 |
"""清除特定 session 的結果"""
|
| 486 |
if session_id in cls._session_results:
|
| 487 |
+
del cls._session_results[session_id]
|