Wen1201 commited on
Commit
52fc51b
·
verified ·
1 Parent(s): de34a92

Upload bayesian_core.py

Browse files
Files changed (1) hide show
  1. bayesian_core.py +31 -114
bayesian_core.py CHANGED
@@ -350,137 +350,54 @@ class BayesianHierarchicalAnalyzer:
350
  }
351
 
352
  def get_model_graph(self):
353
- """
354
- 生成模型 DAG 圖(返回 graphviz 物件)
355
- 使用正確的 WinBUGS/BUGS 標準風格
356
-
357
- Returns:
358
- graphviz.Digraph 物件
359
- """
360
  if self.model is None:
361
  raise ValueError("請先執行分析")
362
 
363
  try:
364
- import graphviz
 
 
 
365
 
366
- # 獲取欄位名稱
367
  control_prefix = getattr(self, 'col_control_win', 'control').replace('_win', '').replace('_battles', '').replace('_total', '')
368
  treatment_prefix = getattr(self, 'col_treatment_win', 'treatment').replace('_win', '').replace('_battles', '').replace('_total', '')
369
  control_win = getattr(self, 'col_control_win', 'control_win')
370
  treatment_win = getattr(self, 'col_treatment_win', 'treatment_win')
371
- control_total = getattr(self, 'col_control_total', 'control_total')
372
- treatment_total = getattr(self, 'col_treatment_total', 'treatment_total')
373
-
374
- # 標籤定義 (純英文,使用上標)
375
- labels = {
376
- # 隨機變數 (圓形)
377
- 'd': 'd',
378
- 'tau': 'tau',
379
- 'sigma': 'sigma',
380
- 'mu': 'mu_i',
381
- 'delta': 'delta_i',
382
- 'delta_new': 'delta_new',
383
- f'p_{control_prefix}': f'p_{control_prefix}_i',
384
- f'p_{treatment_prefix}': f'p_{treatment_prefix}_i',
385
- 'or_speed': 'OR',
386
-
387
- # 觀測資料 (矩形)
388
- f'{control_win}_obs': f'r_{control_prefix}_i',
389
- f'{treatment_win}_obs': f'r_{treatment_prefix}_i',
390
- }
391
-
392
- # 使用 PyMC 生成基本圖
393
- gv_original = pm.model_to_graphviz(self.model)
394
-
395
- # 創建新圖 - WinBUGS 標準風格
396
- dot = graphviz.Digraph(comment='Bayesian Hierarchical Model')
397
- dot.attr(rankdir='TB')
398
- dot.attr('node', fontname='Times-Italic', fontsize='16')
399
- dot.attr('edge', color='black', penwidth='1.2', arrowsize='0.7')
400
-
401
- # WinBUGS 標準: 所有變數都是圓形 (除了觀測資料)
402
- circular_nodes = [
403
- 'd', 'tau', 'sigma', 'mu', 'delta', 'delta_new',
404
- f'p_{control_prefix}', f'p_{treatment_prefix}', 'or_speed'
405
- ]
406
 
407
- # 矩形:有觀測資料
408
- observed_nodes = [f'{control_win}_obs', f'{treatment_win}_obs']
409
 
410
- # 解析並重建
411
- lines = gv_original.source.split('\n')
412
-
413
- for line in lines:
414
- line = line.strip()
415
- if not line or line.startswith('digraph') or line == '}':
416
- continue
417
-
418
- # 節點定義
419
- if '[label=' in line:
420
- node = line.split('[')[0].strip()
421
- label = labels.get(node, node)
422
-
423
- if node in circular_nodes:
424
- # WinBUGS: 所有變數 = 圓形,白底,黑框
425
- dot.node(node, label,
426
- shape='circle',
427
- style='filled',
428
- fillcolor='white',
429
- color='black',
430
- penwidth='2.0',
431
- width='1.0',
432
- height='1.0',
433
- fixedsize='true')
434
-
435
- elif node in observed_nodes:
436
- # WinBUGS: 觀測資料 = 矩形,白底,黑框
437
- dot.node(node, label,
438
- shape='box',
439
- style='filled',
440
- fillcolor='white',
441
- color='black',
442
- penwidth='2.0',
443
- width='1.2',
444
- height='0.8')
445
-
446
- # 邊定義
447
- elif '->' in line:
448
- parts = line.replace(';', '').split('->')
449
- if len(parts) == 2:
450
- src = parts[0].strip()
451
- dst = parts[1].strip().split('[')[0].strip()
452
- dot.edge(src, dst)
453
-
454
- # 手動添加固定常數節點 (雙層矩形)
455
- dot.node(f'{control_total}_const', f'n_{control_prefix}_i',
456
- shape='box',
457
- style='filled',
458
- fillcolor='white',
459
- color='black',
460
- penwidth='4.0',
461
- width='1.2',
462
- height='0.8',
463
- peripheries='2') # 雙層框
464
 
465
- dot.node(f'{treatment_total}_const', f'n_{treatment_prefix}_i',
466
- shape='box',
467
- style='filled',
468
- fillcolor='white',
469
- color='black',
470
- penwidth='4.0',
471
- width='1.2',
472
- height='0.8',
473
- peripheries='2') # 雙層框
474
 
475
- # 添加常數到觀測值的連接
476
- dot.edge(f'{control_total}_const', f'{control_win}_obs')
477
- dot.edge(f'{treatment_total}_const', f'{treatment_win}_obs')
478
 
479
- return dot
480
 
481
  except Exception as e:
482
  raise Exception(f"無法生成 DAG 圖: {str(e)}")
483
-
 
484
  def get_session_results(cls, session_id):
485
  """獲取特定 session 的結果"""
486
  return cls._session_results.get(session_id)
 
350
  }
351
 
352
  def get_model_graph(self):
353
+ """生成模型 DAG 圖(返回 graphviz 物件)"""
 
 
 
 
 
 
354
  if self.model is None:
355
  raise ValueError("請先執行分析")
356
 
357
  try:
358
+ import re
359
+
360
+ # 使用 PyMC 原本的圖
361
+ gv = pm.model_to_graphviz(self.model)
362
 
363
+ # 獲取欄位資訊 (動態)
364
  control_prefix = getattr(self, 'col_control_win', 'control').replace('_win', '').replace('_battles', '').replace('_total', '')
365
  treatment_prefix = getattr(self, 'col_treatment_win', 'treatment').replace('_win', '').replace('_battles', '').replace('_total', '')
366
  control_win = getattr(self, 'col_control_win', 'control_win')
367
  treatment_win = getattr(self, 'col_treatment_win', 'treatment_win')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
+ # 替換標籤 (改文字,不改圖結構)
370
+ source = gv.source
371
 
372
+ # 替換模型變數的標籤
373
+ replacements = {
374
+ # PyMC 變數名 -> 顯示名稱
375
+ f'label="{f"p_{control_prefix}"}"': f'label="p_{control_prefix}[i]\\n({control_prefix.capitalize()} Win Rate)"',
376
+ f'label="{f"p_{treatment_prefix}"}"': f'label="p_{treatment_prefix}[i]\\n({treatment_prefix.capitalize()} Win Rate)"',
377
+ f'label="{control_win}_obs"': f'label="{control_win}_obs[i]\\n({control_prefix.capitalize()} Wins)"',
378
+ f'label="{treatment_win}_obs"': f'label="{treatment_win}_obs[i]\\n({treatment_prefix.capitalize()} Wins)"',
379
+ 'label="d"': f'label="d\\n(Overall Effect\\n{treatment_prefix} vs {control_prefix})"',
380
+ 'label="tau"': 'label="tau\\n(Precision)"',
381
+ 'label="sigma"': 'label="sigma\\n(Between-Pair SD)"',
382
+ 'label="mu"': 'label="mu[i]\\n(Baseline)"',
383
+ 'label="delta"': f'label="delta[i]\\n({treatment_prefix.capitalize()} Advantage)"',
384
+ 'label="delta_new"': 'label="delta_new\\n(New Pair)"',
385
+ 'label="or_speed"': 'label="OR\\n(Odds Ratio)"',
386
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
+ # 執行替換
389
+ for old, new in replacements.items():
390
+ source = source.replace(old, new)
 
 
 
 
 
 
391
 
392
+ # 更新 graphviz 物件
393
+ gv.source = source
 
394
 
395
+ return gv
396
 
397
  except Exception as e:
398
  raise Exception(f"無法生成 DAG 圖: {str(e)}")
399
+
400
+ @classmethod
401
  def get_session_results(cls, session_id):
402
  """獲取特定 session 的結果"""
403
  return cls._session_results.get(session_id)