Wen1201 commited on
Commit
328bfd2
·
verified ·
1 Parent(s): 17c7c7e

Upload bayesian_core.py

Browse files
Files changed (1) hide show
  1. bayesian_core.py +120 -5
bayesian_core.py CHANGED
@@ -349,14 +349,129 @@ class BayesianHierarchicalAnalyzer:
349
  'heterogeneity': heterogeneity
350
  }
351
 
352
- def get_model_graph(self):
353
- """生成模型 DAG 圖(返回 graphviz 物件)"""
 
 
 
 
 
 
 
 
354
  if self.model is None:
355
  raise ValueError("請先執行分析")
356
 
357
  try:
358
- gv = pm.model_to_graphviz(self.model)
359
- return gv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  except Exception as e:
361
  raise Exception(f"無法生成 DAG 圖: {str(e)}")
362
 
@@ -369,4 +484,4 @@ class BayesianHierarchicalAnalyzer:
369
  def clear_session_results(cls, session_id):
370
  """清除特定 session 的結果"""
371
  if session_id in cls._session_results:
372
- del cls._session_results[session_id]
 
349
  'heterogeneity': heterogeneity
350
  }
351
 
352
+ def get_model_graph(self, language='zh'):
353
+ """
354
+ 生成模型 DAG 圖(返回 graphviz 物件)
355
+
356
+ Args:
357
+ language: 'zh' 中文 | 'en' 英文
358
+
359
+ Returns:
360
+ graphviz.Digraph 物件
361
+ """
362
  if self.model is None:
363
  raise ValueError("請先執行分析")
364
 
365
  try:
366
+ import graphviz
367
+
368
+ # 獲取欄位名稱
369
+ control_prefix = getattr(self, 'col_control_win', 'control').replace('_win', '').replace('_battles', '').replace('_total', '')
370
+ treatment_prefix = getattr(self, 'col_treatment_win', 'treatment').replace('_win', '').replace('_battles', '').replace('_total', '')
371
+ control_win = getattr(self, 'col_control_win', 'control_win')
372
+ treatment_win = getattr(self, 'col_treatment_win', 'treatment_win')
373
+
374
+ # 定義中英文對照
375
+ labels_zh = {
376
+ # 先驗參數
377
+ 'd': f'整體效應 (d)\n{treatment_prefix} vs {control_prefix}',
378
+ 'tau': '精度 (tau)',
379
+ 'sigma': '配對間標準差 (σ)',
380
+
381
+ # 階層參數
382
+ 'mu': f'基線對數勝算 (μ[i])\n各道館基準',
383
+ 'delta': f'配對特定效應 (δ[i])\n各道館{treatment_prefix}優勢',
384
+ 'delta_new': f'新配對預測 (δ_new)',
385
+
386
+ # 轉換參數
387
+ f'p_{control_prefix}': f'{control_prefix.capitalize()}勝率 (p_{control_prefix}[i])',
388
+ f'p_{treatment_prefix}': f'{treatment_prefix.capitalize()}勝率 (p_{treatment_prefix}[i])',
389
+
390
+ # 觀測值
391
+ f'{control_win}_obs': f'{control_prefix.capitalize()}勝場觀測 ({control_win}_obs[i])',
392
+ f'{treatment_win}_obs': f'{treatment_prefix.capitalize()}勝場觀測 ({treatment_win}_obs[i])',
393
+
394
+ # 其他
395
+ 'or_speed': f'勝算比 (OR)\nexp(d)'
396
+ }
397
+
398
+ labels_en = {
399
+ # Priors
400
+ 'd': f'Overall Effect (d)\n{treatment_prefix} vs {control_prefix}',
401
+ 'tau': 'Precision (tau)',
402
+ 'sigma': 'Between-Pair SD (σ)',
403
+
404
+ # Hierarchy
405
+ 'mu': f'Baseline Log-Odds (μ[i])\nGym Baselines',
406
+ 'delta': f'Pair-Specific Effect (δ[i])\n{treatment_prefix.capitalize()} Advantage per Gym',
407
+ 'delta_new': f'New Pair Prediction (δ_new)',
408
+
409
+ # Transformations
410
+ f'p_{control_prefix}': f'{control_prefix.capitalize()} Win Rate (p_{control_prefix}[i])',
411
+ f'p_{treatment_prefix}': f'{treatment_prefix.capitalize()} Win Rate (p_{treatment_prefix}[i])',
412
+
413
+ # Observations
414
+ f'{control_win}_obs': f'{control_prefix.capitalize()} Wins Obs ({control_win}_obs[i])',
415
+ f'{treatment_win}_obs': f'{treatment_prefix.capitalize()} Wins Obs ({treatment_win}_obs[i])',
416
+
417
+ # Others
418
+ 'or_speed': f'Odds Ratio (OR)\nexp(d)'
419
+ }
420
+
421
+ # 選擇語言
422
+ labels = labels_zh if language == 'zh' else labels_en
423
+
424
+ # 使用 PyMC 生成基本圖
425
+ gv_original = pm.model_to_graphviz(self.model)
426
+
427
+ # 創建新圖,使用自定義標籤
428
+ dot = graphviz.Digraph(comment='Bayesian Hierarchical Model')
429
+ dot.attr(rankdir='TB') # 從上到下
430
+ dot.attr('node', shape='ellipse', style='filled', fontname='Arial')
431
+
432
+ # 解析原始圖的結構
433
+ lines = gv_original.source.split('\n')
434
+
435
+ # 重建圖,替換標籤
436
+ for line in lines:
437
+ line = line.strip()
438
+
439
+ if not line or line.startswith('digraph') or line == '}':
440
+ continue
441
+
442
+ # 處理節點定義
443
+ if '[label=' in line:
444
+ # 提取節點名稱
445
+ node_name = line.split('[')[0].strip()
446
+
447
+ # 使用自定義標籤
448
+ custom_label = labels.get(node_name, node_name)
449
+
450
+ # 設定節點樣式
451
+ if 'obs' in node_name:
452
+ # 觀測值 - 方框,淺紅色
453
+ dot.node(node_name, custom_label, shape='box', fillcolor='lightcoral')
454
+ elif node_name in ['sigma', f'p_{control_prefix}', f'p_{treatment_prefix}']:
455
+ # 確定性節點 - 菱形,灰色
456
+ dot.node(node_name, custom_label, shape='diamond', fillcolor='lightgray')
457
+ elif node_name in ['mu', 'delta']:
458
+ # 階層參數 - 橢圓,綠色
459
+ dot.node(node_name, custom_label, fillcolor='lightgreen')
460
+ else:
461
+ # 先驗參數 - 橢圓,黃色
462
+ dot.node(node_name, custom_label, fillcolor='lightyellow')
463
+
464
+ # 處理邊
465
+ elif '->' in line:
466
+ # 提取邊的定義
467
+ edge_parts = line.replace(';', '').split('->')
468
+ if len(edge_parts) == 2:
469
+ from_node = edge_parts[0].strip()
470
+ to_node = edge_parts[1].strip().split('[')[0].strip()
471
+ dot.edge(from_node, to_node)
472
+
473
+ return dot
474
+
475
  except Exception as e:
476
  raise Exception(f"無法生成 DAG 圖: {str(e)}")
477
 
 
484
  def clear_session_results(cls, session_id):
485
  """清除特定 session 的結果"""
486
  if session_id in cls._session_results:
487
+ del cls._session_results[session_id]