shivapriyasom commited on
Commit
c744750
Β·
verified Β·
1 Parent(s): 2e034b5

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +88 -76
inference.py CHANGED
@@ -501,25 +501,41 @@ def create_all_shap_plots(user_inputs, max_display=10):
501
 
502
 
503
  # ---------------------------------------------------------------------------
504
- # Icon array β€” proper stick figures
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  # ---------------------------------------------------------------------------
506
 
507
- def _stick_figure_shapes(cx, cy, color, s=0.42):
508
  """
509
- Draw a classic stick figure centred at (cx, cy) with scale s.
510
-
511
- Parts:
512
- - head : circle
513
- - spine : vertical line from neck to hips
514
- - arms : angled lines left/right from mid-spine
515
- - legs : angled lines left/right from hips
 
 
 
516
  """
517
  shapes = []
518
- lw = dict(color=color, width=max(1, s * 4)) # line width scales with size
519
 
520
- # ── head ──────────────────────────────────────────────────────────────
521
- hr = s * 0.22 # head radius
522
- hy = cy + s * 0.60 # head centre y
523
  shapes.append(dict(
524
  type="circle", xref="x", yref="y",
525
  x0=cx - hr, y0=hy - hr,
@@ -528,54 +544,29 @@ def _stick_figure_shapes(cx, cy, color, s=0.42):
528
  line=dict(color=color, width=0),
529
  ))
530
 
531
- # key y levels
532
- neck_y = cy + s * 0.35 # base of head / top of spine
533
- hip_y = cy - s * 0.20 # bottom of spine / top of legs
534
- arm_y = cy + s * 0.15 # where arms branch from spine
535
- foot_y = cy - s * 0.60 # feet
536
-
537
- # ── spine ─────────────────────────────────────────────────────────────
538
- shapes.append(dict(
539
- type="line", xref="x", yref="y",
540
- x0=cx, y0=neck_y,
541
- x1=cx, y1=hip_y,
542
- line=lw,
543
- ))
544
-
545
- # ── arms (angled outward) ──────────────────────────────────────────────
546
- arm_dx = s * 0.35
547
- arm_dy = s * 0.18
548
- # left arm
549
- shapes.append(dict(
550
- type="line", xref="x", yref="y",
551
- x0=cx, y0=arm_y,
552
- x1=cx - arm_dx, y1=arm_y - arm_dy,
553
- line=lw,
554
- ))
555
- # right arm
556
- shapes.append(dict(
557
- type="line", xref="x", yref="y",
558
- x0=cx, y0=arm_y,
559
- x1=cx + arm_dx, y1=arm_y - arm_dy,
560
- line=lw,
561
- ))
562
-
563
- # ── legs (angled outward) ──────────────────────────────────────────────
564
- leg_dx = s * 0.28
565
- # left leg
566
- shapes.append(dict(
567
- type="line", xref="x", yref="y",
568
- x0=cx, y0=hip_y,
569
- x1=cx - leg_dx, y1=foot_y,
570
- line=lw,
571
- ))
572
- # right leg
573
- shapes.append(dict(
574
- type="line", xref="x", yref="y",
575
- x0=cx, y0=hip_y,
576
- x1=cx + leg_dx, y1=foot_y,
577
- line=lw,
578
- ))
579
 
580
  return shapes
581
 
@@ -591,20 +582,36 @@ def icon_array(probability, outcome):
591
  }
592
 
593
  event_label, no_event_label = outcome_labels.get(outcome, ("Event", "No Event"))
594
- n_total = 100
595
- n_event = round(probability * n_total)
596
- n_no_event = n_total - n_event
597
  cols, rows = 10, 10
598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  all_shapes = []
600
  icon_idx = 0
601
 
602
- for row in range(rows): # row 0 = top
603
- for col in range(cols): # col 0 = left
604
  color = "#e05555" if icon_idx < n_event else "#3bbfad"
605
- cx = col
606
- cy = rows - 1 - row # invert so row 0 is at top
607
- all_shapes.extend(_stick_figure_shapes(cx, cy, color, s=0.38))
608
  icon_idx += 1
609
 
610
  fig = go.Figure()
@@ -612,27 +619,32 @@ def icon_array(probability, outcome):
612
  title=dict(
613
  text=(
614
  f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
615
- f"<span style='font-size:12px;color:#e05555'>β–  {event_label}: {n_event}%</span>"
 
616
  f"&nbsp;&nbsp;"
617
- f"<span style='font-size:12px;color:#3bbfad'>β–  {no_event_label}: {n_no_event}%</span>"
 
618
  ),
619
  x=0.5, xanchor="center",
620
  font=dict(size=14, color="black"),
621
  ),
622
  shapes=all_shapes,
623
  xaxis=dict(
624
- range=[-0.7, cols - 0.3],
625
  showgrid=False, zeroline=False, showticklabels=False,
626
  fixedrange=True,
627
  ),
628
  yaxis=dict(
629
- range=[-0.8, rows - 0.2],
630
  showgrid=False, zeroline=False, showticklabels=False,
631
  fixedrange=True,
 
 
 
632
  ),
633
- height=440,
634
- width=420,
635
- margin=dict(l=10, r=10, t=85, b=10),
636
  plot_bgcolor="white",
637
  paper_bgcolor="white",
638
  )
 
501
 
502
 
503
  # ---------------------------------------------------------------------------
504
+ # Icon array
505
+ # ---------------------------------------------------------------------------
506
+ # Root cause of previous gaps / distortion:
507
+ # Plotly shape coords are in DATA units. If px-per-data-unit differs on
508
+ # x vs y axes the circle head becomes an ellipse and spacing looks uneven.
509
+ #
510
+ # Fix:
511
+ # β€’ Use EQUAL axis spans on x and y (both = cols + 2*pad = 10.3)
512
+ # β€’ Set width and height so that usable pixels are EQUAL on both axes:
513
+ # usable_w = W - margin_l - margin_r = W - 20
514
+ # usable_h = H - margin_t - margin_b = H - 100
515
+ # usable_w == usable_h β†’ H = W + 80
516
+ # β€’ This guarantees 1 data-unit = same number of pixels on both axes,
517
+ # so circles are round and spacing is perfectly uniform.
518
  # ---------------------------------------------------------------------------
519
 
520
+ def _stick_figure(cx, cy, color, s):
521
  """
522
+ Returns Plotly shape dicts for a stick figure centred at (cx, cy).
523
+ s = scale (data units). With a cell size of 1.0, s β‰ˆ 0.46 gives
524
+ a figure that fills ~75 % of the cell vertically.
525
+
526
+ Anatomy (all offsets relative to cy):
527
+ head centre : cy + s*0.55 radius s*0.18
528
+ neck top : cy + s*0.35
529
+ hip : cy - s*0.15
530
+ arm branch : cy + s*0.18
531
+ foot : cy - s*0.55
532
  """
533
  shapes = []
534
+ lw = dict(color=color, width=1.8) # fixed pixel width β€” looks consistent
535
 
536
+ # head
537
+ hr = s * 0.18
538
+ hy = cy + s * 0.55
539
  shapes.append(dict(
540
  type="circle", xref="x", yref="y",
541
  x0=cx - hr, y0=hy - hr,
 
544
  line=dict(color=color, width=0),
545
  ))
546
 
547
+ neck_y = cy + s * 0.35
548
+ hip_y = cy - s * 0.15
549
+ arm_y = cy + s * 0.18
550
+ foot_y = cy - s * 0.55
551
+
552
+ # spine
553
+ shapes.append(dict(type="line", xref="x", yref="y",
554
+ x0=cx, y0=neck_y, x1=cx, y1=hip_y, line=lw))
555
+
556
+ # arms
557
+ adx = s * 0.32
558
+ ady = s * 0.15
559
+ shapes.append(dict(type="line", xref="x", yref="y",
560
+ x0=cx, y0=arm_y, x1=cx - adx, y1=arm_y - ady, line=lw))
561
+ shapes.append(dict(type="line", xref="x", yref="y",
562
+ x0=cx, y0=arm_y, x1=cx + adx, y1=arm_y - ady, line=lw))
563
+
564
+ # legs
565
+ ldx = s * 0.26
566
+ shapes.append(dict(type="line", xref="x", yref="y",
567
+ x0=cx, y0=hip_y, x1=cx - ldx, y1=foot_y, line=lw))
568
+ shapes.append(dict(type="line", xref="x", yref="y",
569
+ x0=cx, y0=hip_y, x1=cx + ldx, y1=foot_y, line=lw))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
 
571
  return shapes
572
 
 
582
  }
583
 
584
  event_label, no_event_label = outcome_labels.get(outcome, ("Event", "No Event"))
585
+ n_event = round(probability * 100)
586
+ n_no_event = 100 - n_event
 
587
  cols, rows = 10, 10
588
 
589
+ # ── Layout constants ──────────────────────────────────────────────────
590
+ # Icons sit on an integer grid 0..9 Γ— 0..9.
591
+ # Padding of 0.65 on each side β†’ axis span = 9 + 2*0.65 = 10.30
592
+ # Margins: left=10, right=10, top=95, bottom=10
593
+ # usable_w = W - 20 ; usable_h = H - 105
594
+ # To ensure px_per_unit identical on both axes: usable_w == usable_h
595
+ # β†’ H = W + 85
596
+ # We also enforce equal axis spans (both 10.30).
597
+
598
+ PAD = 0.65
599
+ W = 400
600
+ H = W + 85 # = 485 β†’ usable = 380 px on both axes
601
+ S = 0.46 # figure scale (β‰ˆ 75 % vertical fill per cell)
602
+
603
+ x_lo, x_hi = -PAD, (cols - 1) + PAD # -0.65 … 9.65 span=10.30
604
+ y_lo, y_hi = -PAD, (rows - 1) + PAD # -0.65 … 9.65 span=10.30
605
+
606
  all_shapes = []
607
  icon_idx = 0
608
 
609
+ for row in range(rows): # row 0 β†’ top of grid
610
+ for col in range(cols): # col 0 β†’ left
611
  color = "#e05555" if icon_idx < n_event else "#3bbfad"
612
+ cx = col
613
+ cy = (rows - 1) - row # invert: row 0 β†’ cy=9 (top)
614
+ all_shapes.extend(_stick_figure(cx, cy, color, S))
615
  icon_idx += 1
616
 
617
  fig = go.Figure()
 
619
  title=dict(
620
  text=(
621
  f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
622
+ f"<span style='font-size:12px;color:#e05555'>"
623
+ f"β–  {event_label}: {n_event}%</span>"
624
  f"&nbsp;&nbsp;"
625
+ f"<span style='font-size:12px;color:#3bbfad'>"
626
+ f"β–  {no_event_label}: {n_no_event}%</span>"
627
  ),
628
  x=0.5, xanchor="center",
629
  font=dict(size=14, color="black"),
630
  ),
631
  shapes=all_shapes,
632
  xaxis=dict(
633
+ range=[x_lo, x_hi],
634
  showgrid=False, zeroline=False, showticklabels=False,
635
  fixedrange=True,
636
  ),
637
  yaxis=dict(
638
+ range=[y_lo, y_hi],
639
  showgrid=False, zeroline=False, showticklabels=False,
640
  fixedrange=True,
641
+ # scaleanchor / scaleratio intentionally OMITTED β€”
642
+ # equal spans + equal usable pixels already guarantee
643
+ # identical px/unit on both axes without distortion.
644
  ),
645
+ width=W,
646
+ height=H,
647
+ margin=dict(l=10, r=10, t=95, b=10),
648
  plot_bgcolor="white",
649
  paper_bgcolor="white",
650
  )