Zeqh Claude Opus 4.8 commited on
Commit
b091c09
·
1 Parent(s): e99bcfe

Confidence heatmap + drop v1 models from dropdown

Browse files

- Model dropdown: only the two dataset-4 retrains (bertv2, robertav2)
- model.py: expose per-token softmax confidence, mean over merged entities
- viz.py + Live Parser: 'shade by confidence' heatmap toggle + debug panel
- Persist parse results across reruns

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Files changed (4) hide show
  1. config.py +1 -6
  2. lib/model.py +10 -4
  3. lib/viz.py +86 -15
  4. pages/1_Live_Parser.py +32 -3
config.py CHANGED
@@ -16,10 +16,7 @@ import os
16
  # Hugging Face owner + the canonical "best model" repo the app loads by default
17
  # and that the Manage Model page overwrites.
18
  HF_OWNER = os.environ.get("HF_OWNER", "Zeqhx")
19
- # Four published models, selectable in the UI:
20
- # v1 = earlier (dataset-2) models, v2 = latest (dataset-4) retrains.
21
- BERT_V1_ID = os.environ.get("DASHBOARD_BERT_V1_ID", f"{HF_OWNER}/cv-parser-bert-v1")
22
- ROBERTA_V1_ID = os.environ.get("DASHBOARD_ROBERTA_V1_ID", f"{HF_OWNER}/cv-parser-roberta-v1")
23
  BERT_V2_ID = os.environ.get("DASHBOARD_BERT_V2_ID", f"{HF_OWNER}/cv-parser-bert-v2")
24
  ROBERTA_V2_ID = os.environ.get("DASHBOARD_ROBERTA_V2_ID", f"{HF_OWNER}/cv-parser-roberta-v2")
25
 
@@ -36,8 +33,6 @@ DEMO_LABEL = "Demo — untrained roberta-base"
36
  # Toggle registry. Each entry: (label, kind, ref). "local" entries are only
37
  # offered when the folder exists (dev machines); "hub" entries are always offered.
38
  MODEL_REGISTRY = [
39
- ("bertv1", "hub", BERT_V1_ID),
40
- ("robertav1", "hub", ROBERTA_V1_ID),
41
  ("bertv2", "hub", BERT_V2_ID),
42
  ("robertav2", "hub", ROBERTA_V2_ID),
43
  ]
 
16
  # Hugging Face owner + the canonical "best model" repo the app loads by default
17
  # and that the Manage Model page overwrites.
18
  HF_OWNER = os.environ.get("HF_OWNER", "Zeqhx")
19
+ # Two published models, selectable in the UI. Both are dataset-4 retrains.
 
 
 
20
  BERT_V2_ID = os.environ.get("DASHBOARD_BERT_V2_ID", f"{HF_OWNER}/cv-parser-bert-v2")
21
  ROBERTA_V2_ID = os.environ.get("DASHBOARD_ROBERTA_V2_ID", f"{HF_OWNER}/cv-parser-roberta-v2")
22
 
 
33
  # Toggle registry. Each entry: (label, kind, ref). "local" entries are only
34
  # offered when the folder exists (dev machines); "hub" entries are always offered.
35
  MODEL_REGISTRY = [
 
 
36
  ("bertv2", "hub", BERT_V2_ID),
37
  ("robertav2", "hub", ROBERTA_V2_ID),
38
  ]
lib/model.py CHANGED
@@ -130,6 +130,8 @@ def predict(text: str, lm: LoadedModel):
130
  attn_dev = attn.to(lm.device)
131
 
132
  logits = lm.model(input_ids=input_ids, attention_mask=attn_dev).logits
 
 
133
  preds = logits.argmax(-1).cpu()
134
 
135
  # Deduplicate overlapping sliding-window tokens by their global char offset.
@@ -142,15 +144,15 @@ def predict(text: str, lm: LoadedModel):
142
  continue # special token or padding
143
  if s in seen:
144
  continue
145
- seen[s] = (s, e, int(preds[w][i]))
146
 
147
  tokens = []
148
  for s in sorted(seen):
149
- _, e, pid = seen[s]
150
  label = lm.id2label.get(pid, "O")
151
  etype = label.split("-", 1)[1] if "-" in label else None
152
  tokens.append({"text": text[s:e], "label": label, "type": etype,
153
- "start": s, "end": e})
154
 
155
  entities = _merge_bio(tokens, text)
156
  return tokens, entities
@@ -171,14 +173,18 @@ def _merge_bio(tokens, text):
171
  if prefix == "B" or cur is None or cur["type"] != etype:
172
  if cur:
173
  entities.append(cur)
174
- cur = {"type": etype, "start": t["start"], "end": t["end"]}
 
175
  else: # I- continuing the same type
176
  cur["end"] = t["end"]
 
177
  if cur:
178
  entities.append(cur)
179
 
180
  for e in entities:
181
  e["text"] = text[e["start"]:e["end"]].strip()
 
 
182
  return [e for e in entities if e["text"]]
183
 
184
 
 
130
  attn_dev = attn.to(lm.device)
131
 
132
  logits = lm.model(input_ids=input_ids, attention_mask=attn_dev).logits
133
+ probs = logits.softmax(-1) # keep the distribution, not just argmax
134
+ conf_all = probs.max(-1).values.cpu() # per-token confidence of the chosen label
135
  preds = logits.argmax(-1).cpu()
136
 
137
  # Deduplicate overlapping sliding-window tokens by their global char offset.
 
144
  continue # special token or padding
145
  if s in seen:
146
  continue
147
+ seen[s] = (s, e, int(preds[w][i]), float(conf_all[w][i]))
148
 
149
  tokens = []
150
  for s in sorted(seen):
151
+ _, e, pid, conf = seen[s]
152
  label = lm.id2label.get(pid, "O")
153
  etype = label.split("-", 1)[1] if "-" in label else None
154
  tokens.append({"text": text[s:e], "label": label, "type": etype,
155
+ "start": s, "end": e, "conf": conf})
156
 
157
  entities = _merge_bio(tokens, text)
158
  return tokens, entities
 
173
  if prefix == "B" or cur is None or cur["type"] != etype:
174
  if cur:
175
  entities.append(cur)
176
+ cur = {"type": etype, "start": t["start"], "end": t["end"],
177
+ "confs": [t.get("conf", 1.0)]}
178
  else: # I- continuing the same type
179
  cur["end"] = t["end"]
180
+ cur["confs"].append(t.get("conf", 1.0))
181
  if cur:
182
  entities.append(cur)
183
 
184
  for e in entities:
185
  e["text"] = text[e["start"]:e["end"]].strip()
186
+ confs = e.pop("confs", []) or [1.0]
187
+ e["conf"] = sum(confs) / len(confs) # mean confidence over member tokens
188
  return [e for e in entities if e["text"]]
189
 
190
 
lib/viz.py CHANGED
@@ -6,6 +6,50 @@ import html
6
  import config
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def _legend() -> str:
10
  items = []
11
  for t in config.ENTITY_TYPES:
@@ -17,48 +61,75 @@ def _legend() -> str:
17
  return '<div style="margin-bottom:10px">' + "".join(items) + "</div>"
18
 
19
 
20
- def render_entities_html(text: str, entities: list[dict]) -> str:
21
- """Original text with entity spans wrapped in coloured marks."""
 
 
 
 
 
22
  ents = sorted((e for e in entities if e["type"] in config.ENTITY_COLORS),
23
  key=lambda e: e["start"])
 
 
24
  out, cursor = [], 0
25
  for e in ents:
26
  if e["start"] < cursor: # skip any overlap defensively
27
  continue
28
  out.append(html.escape(text[cursor:e["start"]]))
29
- color = config.ENTITY_COLORS[e["type"]]
30
  label = config.ENTITY_LABELS[e["type"]]
 
 
 
 
 
 
 
31
  out.append(
32
- f'<mark style="background:{color};color:#fff;padding:1px 4px;'
33
- f'border-radius:4px" title="{label}">'
34
  f'{html.escape(text[e["start"]:e["end"]])}'
35
  f'<sub style="font-size:0.6em;opacity:.85"> {label}</sub></mark>'
36
  )
37
  cursor = e["end"]
38
  out.append(html.escape(text[cursor:]))
39
  body = "".join(out).replace("\n", "<br>")
40
- return (_legend() +
41
  f'<div style="line-height:2.1;font-family:system-ui;font-size:0.95rem;'
42
  f'border:1px solid #ddd;border-radius:8px;padding:16px;'
43
  f'max-height:520px;overflow:auto">{body}</div>')
44
 
45
 
46
- def render_tokens_html(tokens: list[dict], limit: int = 400) -> str:
47
- """Sub-word token chips, coloured by predicted label — the 'tokenization view'."""
 
 
 
 
 
 
 
 
 
48
  chips = []
49
- for t in tokens[:limit]:
50
  txt = html.escape(t["text"]) or "·"
51
- if t["type"] in config.ENTITY_COLORS:
52
- color = config.ENTITY_COLORS[t["type"]]
53
- style = f"background:{color};color:#fff"
 
 
 
 
54
  else:
55
  style = "background:#eee;color:#555"
56
  chips.append(
57
- f'<span style="{style};padding:2px 6px;border-radius:4px;margin:2px;'
58
- f'display:inline-block;font-family:monospace;font-size:0.8rem">{txt}</span>'
59
  )
60
  more = "" if len(tokens) <= limit else f'<div style="color:#888;margin-top:8px">… +{len(tokens)-limit} more tokens</div>'
61
- return (_legend() +
62
  f'<div style="border:1px solid #ddd;border-radius:8px;padding:12px;'
63
  f'max-height:420px;overflow:auto">{"".join(chips)}{more}</div>')
64
 
 
6
  import config
7
 
8
 
9
+ def _hex_to_rgba(hex_color: str, alpha: float) -> str:
10
+ """'#2a9d8f' + alpha -> 'rgba(42,157,143,0.83)'."""
11
+ h = hex_color.lstrip("#")
12
+ r, g, b = (int(h[i:i + 2], 16) for i in (0, 2, 4))
13
+ return f"rgba({r},{g},{b},{alpha:.2f})"
14
+
15
+
16
+ # --- Confidence heatmap ramp: low (red) -> mid (orange) -> high (teal) -------
17
+ _RAMP_LO = (214, 40, 40) # low confidence
18
+ _RAMP_MID = (244, 162, 97) # mid
19
+ _RAMP_HI = (42, 157, 143) # high confidence
20
+
21
+
22
+ def _lerp(a: tuple, b: tuple, t: float) -> tuple:
23
+ return tuple(round(a[i] + (b[i] - a[i]) * t) for i in range(3))
24
+
25
+
26
+ def _conf_color(t: float) -> str:
27
+ """0..1 -> rgb on the red->orange->teal ramp."""
28
+ t = max(0.0, min(1.0, t))
29
+ r, g, b = (_lerp(_RAMP_LO, _RAMP_MID, t / 0.5) if t < 0.5
30
+ else _lerp(_RAMP_MID, _RAMP_HI, (t - 0.5) / 0.5))
31
+ return f"rgb({r},{g},{b})"
32
+
33
+
34
+ def _stretch(conf: float, lo: float, hi: float) -> float:
35
+ """Contrast-stretch a confidence into 0..1 across the observed [lo, hi]."""
36
+ if hi - lo < 1e-6:
37
+ return 1.0
38
+ return (conf - lo) / (hi - lo)
39
+
40
+
41
+ def _conf_legend(lo: float, hi: float) -> str:
42
+ grad = (f"linear-gradient(to right, rgb{_RAMP_LO}, rgb{_RAMP_MID}, rgb{_RAMP_HI})")
43
+ return (
44
+ '<div style="margin-bottom:10px;font-size:0.8rem;color:#444">'
45
+ f'<span style="margin-right:8px">low ({lo:.0%})</span>'
46
+ f'<span style="display:inline-block;width:180px;height:12px;background:{grad};'
47
+ 'border-radius:3px;vertical-align:middle"></span>'
48
+ f'<span style="margin-left:8px">high ({hi:.0%}) — colour stretched across this CV</span>'
49
+ '</div>'
50
+ )
51
+
52
+
53
  def _legend() -> str:
54
  items = []
55
  for t in config.ENTITY_TYPES:
 
61
  return '<div style="margin-bottom:10px">' + "".join(items) + "</div>"
62
 
63
 
64
+ def render_entities_html(text: str, entities: list[dict],
65
+ shade_by_conf: bool = False) -> str:
66
+ """Original text with entity spans wrapped in coloured marks.
67
+
68
+ When ``shade_by_conf`` is set, each mark's background opacity reflects the
69
+ model's mean confidence for that entity, and the % is shown in the tooltip.
70
+ """
71
  ents = sorted((e for e in entities if e["type"] in config.ENTITY_COLORS),
72
  key=lambda e: e["start"])
73
+ cvals = [e.get("conf", 1.0) for e in ents] or [1.0]
74
+ lo, hi = min(cvals), max(cvals)
75
  out, cursor = [], 0
76
  for e in ents:
77
  if e["start"] < cursor: # skip any overlap defensively
78
  continue
79
  out.append(html.escape(text[cursor:e["start"]]))
80
+ hex_color = config.ENTITY_COLORS[e["type"]]
81
  label = config.ENTITY_LABELS[e["type"]]
82
+ conf = e.get("conf", 1.0)
83
+ if shade_by_conf:
84
+ bg = _conf_color(_stretch(conf, lo, hi))
85
+ title = f"{label} · {conf:.0%} confidence"
86
+ else:
87
+ bg = hex_color
88
+ title = label
89
  out.append(
90
+ f'<mark style="background:{bg};color:#fff;padding:1px 4px;'
91
+ f'border-radius:4px" title="{title}">'
92
  f'{html.escape(text[e["start"]:e["end"]])}'
93
  f'<sub style="font-size:0.6em;opacity:.85"> {label}</sub></mark>'
94
  )
95
  cursor = e["end"]
96
  out.append(html.escape(text[cursor:]))
97
  body = "".join(out).replace("\n", "<br>")
98
+ return ((_conf_legend(lo, hi) if shade_by_conf else _legend()) +
99
  f'<div style="line-height:2.1;font-family:system-ui;font-size:0.95rem;'
100
  f'border:1px solid #ddd;border-radius:8px;padding:16px;'
101
  f'max-height:520px;overflow:auto">{body}</div>')
102
 
103
 
104
+ def render_tokens_html(tokens: list[dict], limit: int = 400,
105
+ shade_by_conf: bool = False) -> str:
106
+ """Sub-word token chips, coloured by predicted label — the 'tokenization view'.
107
+
108
+ When ``shade_by_conf`` is set, each chip's background opacity reflects the
109
+ model's confidence in that token's label (low-confidence chips fade out),
110
+ and the exact % shows on hover.
111
+ """
112
+ shown = tokens[:limit]
113
+ cvals = [t.get("conf", 1.0) for t in shown] or [1.0]
114
+ lo, hi = min(cvals), max(cvals)
115
  chips = []
116
+ for t in shown:
117
  txt = html.escape(t["text"]) or "·"
118
+ conf = t.get("conf", 1.0)
119
+ title = f"{t['label']} · {conf:.0%}"
120
+ if shade_by_conf:
121
+ # Pure confidence heatmap: every token coloured on the ramp.
122
+ style = f"background:{_conf_color(_stretch(conf, lo, hi))};color:#fff"
123
+ elif t["type"] in config.ENTITY_COLORS:
124
+ style = f"background:{config.ENTITY_COLORS[t['type']]};color:#fff"
125
  else:
126
  style = "background:#eee;color:#555"
127
  chips.append(
128
+ f'<span title="{title}" style="{style};padding:2px 6px;border-radius:4px;'
129
+ f'margin:2px;display:inline-block;font-family:monospace;font-size:0.8rem">{txt}</span>'
130
  )
131
  more = "" if len(tokens) <= limit else f'<div style="color:#888;margin-top:8px">… +{len(tokens)-limit} more tokens</div>'
132
+ return ((_conf_legend(lo, hi) if shade_by_conf else _legend()) +
133
  f'<div style="border:1px solid #ddd;border-radius:8px;padding:12px;'
134
  f'max-height:420px;overflow:auto">{"".join(chips)}{more}</div>')
135
 
pages/1_Live_Parser.py CHANGED
@@ -35,11 +35,15 @@ else:
35
 
36
  run = st.button("Parse CV", type="primary", disabled=not text.strip())
37
 
38
- # ---- Output -----------------------------------------------------------------
39
  if run and text.strip():
40
  with st.spinner("Tokenizing and classifying…"):
41
  tokens, entities = predict(text, lm)
 
42
 
 
 
 
 
43
  grouped = group_entities(entities)
44
  c1, c2, c3, c4 = st.columns(4)
45
  c1.metric("Sub-word tokens", len(tokens))
@@ -47,16 +51,41 @@ if run and text.strip():
47
  c3.metric("Skills", len(grouped["SKILL"]))
48
  c4.metric("Education", len(grouped["EDUCATION"]))
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  tab_ent, tab_tok, tab_card = st.tabs(
51
  ["🏷️ Highlighted entities", "🔢 Tokenization view", "🗂️ Structured summary"])
52
 
53
  with tab_ent:
54
- st.markdown(viz.render_entities_html(text, entities), unsafe_allow_html=True)
 
55
 
56
  with tab_tok:
57
  st.caption("Each chip is one sub-word token produced by the tokenizer, "
58
  "coloured by its predicted label.")
59
- st.markdown(viz.render_tokens_html(tokens), unsafe_allow_html=True)
 
60
 
61
  with tab_card:
62
  cols = st.columns(3)
 
35
 
36
  run = st.button("Parse CV", type="primary", disabled=not text.strip())
37
 
 
38
  if run and text.strip():
39
  with st.spinner("Tokenizing and classifying…"):
40
  tokens, entities = predict(text, lm)
41
+ st.session_state["parse"] = {"text": text, "tokens": tokens, "entities": entities}
42
 
43
+ # ---- Output (persists across reruns, e.g. when flipping the toggle) ---------
44
+ parse = st.session_state.get("parse")
45
+ if parse:
46
+ text, tokens, entities = parse["text"], parse["tokens"], parse["entities"]
47
  grouped = group_entities(entities)
48
  c1, c2, c3, c4 = st.columns(4)
49
  c1.metric("Sub-word tokens", len(tokens))
 
51
  c3.metric("Skills", len(grouped["SKILL"]))
52
  c4.metric("Education", len(grouped["EDUCATION"]))
53
 
54
+ shade = st.toggle(
55
+ "🌡️ Shade by confidence",
56
+ value=False,
57
+ help="Fade each entity / token by how confident the model is in its label "
58
+ "(softmax probability). Hover any span to see the exact %.",
59
+ )
60
+
61
+ with st.expander("🐞 Debug — toggle state & confidence distribution"):
62
+ confs = [t.get("conf", 1.0) for t in tokens]
63
+ st.write({
64
+ "toggle `shade` value": shade,
65
+ "n_tokens": len(tokens),
66
+ "conf min": round(min(confs), 4) if confs else None,
67
+ "conf mean": round(sum(confs) / len(confs), 4) if confs else None,
68
+ "conf max": round(max(confs), 4) if confs else None,
69
+ })
70
+ st.caption("If min≈max≈1.0, the model is uniformly confident and shading is "
71
+ "visually subtle by design — not a bug. Entity confidences:")
72
+ st.table([
73
+ {"entity": e["text"], "type": e["type"], "conf %": f"{e.get('conf', 1.0):.1%}"}
74
+ for e in entities
75
+ ])
76
+
77
  tab_ent, tab_tok, tab_card = st.tabs(
78
  ["🏷️ Highlighted entities", "🔢 Tokenization view", "🗂️ Structured summary"])
79
 
80
  with tab_ent:
81
+ st.markdown(viz.render_entities_html(text, entities, shade_by_conf=shade),
82
+ unsafe_allow_html=True)
83
 
84
  with tab_tok:
85
  st.caption("Each chip is one sub-word token produced by the tokenizer, "
86
  "coloured by its predicted label.")
87
+ st.markdown(viz.render_tokens_html(tokens, shade_by_conf=shade),
88
+ unsafe_allow_html=True)
89
 
90
  with tab_card:
91
  cols = st.columns(3)