lukeingawesome committed on
Commit
6410e2a
·
verified ·
1 Parent(s): 10679da

Upload modeling_chest2vec_labeler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_chest2vec_labeler.py +65 -0
modeling_chest2vec_labeler.py CHANGED
@@ -282,6 +282,71 @@ class Chest2VecLabelerModel(PreTrainedModel):
282
  return res
283
 
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  def report_f1(gt_reports: List[str], pred_reports: List[str], model=None, tokenizer=None,
286
  model_id: str = "chest2vec/chest2vec_labeler", **kw) -> Dict[str, Any]:
287
  """Convenience wrapper: load the labeler (if not supplied) and score GT vs predicted reports."""
 
282
  return res
283
 
284
 
285
+ # ---- per-label best F1 (threshold swept to maximize F1) vs ground-truth labels ----
286
def _to_positive_matrix(self, gt, names):
    """Coerce ground-truth labels to a ``[N, len(names)]`` binary positive matrix.

    Args:
        gt: Ground-truth labels. Either a pandas ``DataFrame`` holding (some of) the
            label columns, an object exposing ``detach`` (e.g. a torch tensor), or
            anything accepted by ``numpy.asarray``. Values may be ternary
            (1 / 0 / -1 / NaN) or already binary; only cells equal to 1 are positive.
        names: Label names, in output column order.

    Returns:
        Integer ndarray of shape ``[N, len(names)]`` holding 1 where the label is
        positive and 0 everywhere else.

    Raises:
        ValueError: If an array/tensor input is not 2-D or does not have exactly one
            column per entry of ``names``.
    """
    import numpy as np

    # pandas is optional: keep only the import inside the try so that errors raised
    # while processing a DataFrame are never mistaken for "pandas is missing".
    try:
        import pandas as pd
    except ImportError:
        pd = None

    if pd is not None and isinstance(gt, pd.DataFrame):
        out = np.zeros((len(gt), len(names)), dtype=int)
        for j, col in enumerate(names):
            # Label columns absent from the frame stay all-zero (no positives).
            if col in gt.columns:
                values = pd.to_numeric(gt[col], errors="coerce").fillna(0).values
                out[:, j] = (values == 1).astype(int)
        return out

    arr = gt.detach().cpu().numpy() if hasattr(gt, "detach") else np.asarray(gt)
    # Enforce the documented shape: without this check a mis-shaped array would be
    # passed through and its columns silently paired with the wrong label names.
    if arr.ndim != 2 or arr.shape[1] != len(names):
        raise ValueError(
            f"ground-truth array must have shape [N, {len(names)}], got {arr.shape}"
        )
    return (arr == 1).astype(int)
304
+
305
@torch.no_grad()
def per_label_best_f1(self, reports: List[str], gt, tokenizer=None, level: str = "leaf",
                      min_pos: int = 30, batch_size: int = 16, max_len: Optional[int] = None,
                      device=None) -> Dict[str, Any]:
    """Sweep the decision threshold per label and report the F1-maximizing operating point.

    Args:
        reports: Report texts to score with ``self.predict_proba``.
        gt: Ground-truth label matrix for ``reports`` (DataFrame with the 137 label
            columns, or array); converted with ``self._to_positive_matrix``.
        tokenizer: Forwarded to ``self.predict_proba``.
        level: ``"leaf"``, ``"upper"`` or ``"anatomy"``. For the non-leaf levels both
            the predicted probabilities and the ground truth are passed through
            ``self.aggregate_hierarchy``; aggregated ground truth is binarized at 0.5.
        min_pos: Minimum number of positives for a label to enter
            ``macro_best_f1_min_pos``.
        batch_size: Forwarded to ``self.predict_proba``.
        max_len: Forwarded to ``self.predict_proba``.
        device: Forwarded to ``self.predict_proba``.

    Returns:
        Dict with ``level``, ``min_pos``, ``macro_best_f1`` (mean over all labels),
        ``macro_best_f1_min_pos`` (mean over labels with at least ``min_pos``
        positives), ``n_labels_min_pos`` and ``per_label``, which maps each label to
        ``{"best_f1", "best_threshold", "n_pos"}``. Labels whose ground truth has no
        positives or only a single class get ``best_f1 == 0.0`` and
        ``best_threshold is None``, and still count towards ``macro_best_f1``.

    Raises:
        ValueError: If ``level`` is not one of the supported values, or if ``gt`` does
            not have one row per report.
    """
    # Validate before running inference: an unknown level used to fall through to the
    # "anatomy" branch and was echoed back in the result as if it had been honored.
    if level not in ("leaf", "upper", "anatomy"):
        raise ValueError(f"level must be 'leaf', 'upper' or 'anatomy', got {level!r}")

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    leaf_names = list(self.config.labels)
    gt_leaf = self._to_positive_matrix(gt, leaf_names)
    # Fail fast on misaligned inputs instead of after the (expensive) forward passes.
    if gt_leaf.shape[0] != len(reports):
        raise ValueError(
            f"gt has {gt_leaf.shape[0]} rows but {len(reports)} reports were given"
        )
    pr_leaf = self.predict_proba(reports, tokenizer=tokenizer, batch_size=batch_size,
                                 max_len=max_len, device=device).numpy()

    if level == "leaf":
        prob, names, gtb = pr_leaf, leaf_names, gt_leaf
    else:
        pu, un, pa, an = self.aggregate_hierarchy(pr_leaf)
        gu, _, ga, _ = self.aggregate_hierarchy(gt_leaf.astype(np.float32))
        if level == "upper":
            prob, names, gtb = pu, un, (gu >= 0.5).astype(int)
        else:
            prob, names, gtb = pa, an, (ga >= 0.5).astype(int)

    per: Dict[str, Any] = {}
    all_best, ge_best = [], []
    for j, lab in enumerate(names):
        t = gtb[:, j].astype(int)
        s = prob[:, j].astype(float)
        npos = int(t.sum())
        if npos == 0 or len(np.unique(t)) < 2:
            # Degenerate label (no positives, or only one class present): no sweep.
            bf, bt = 0.0, None
        else:
            p, r, thr = precision_recall_curve(t, s)
            # precision/recall carry one trailing point with no matching threshold;
            # drop it so f1 lines up with thr. The epsilon guards against 0/0.
            f1 = (2 * p * r / (p + r + 1e-12))[:-1]
            bi = int(np.nanargmax(f1))
            bf, bt = float(f1[bi]), float(thr[bi])
        per[lab] = {"best_f1": bf, "best_threshold": bt, "n_pos": npos}
        all_best.append(bf)
        if npos >= min_pos:
            ge_best.append(bf)

    return {"level": level, "min_pos": min_pos,
            "macro_best_f1": float(np.mean(all_best)) if all_best else 0.0,
            "macro_best_f1_min_pos": float(np.mean(ge_best)) if ge_best else 0.0,
            "n_labels_min_pos": len(ge_best), "per_label": per}
348
+
349
+
350
  def report_f1(gt_reports: List[str], pred_reports: List[str], model=None, tokenizer=None,
351
  model_id: str = "chest2vec/chest2vec_labeler", **kw) -> Dict[str, Any]:
352
  """Convenience wrapper: load the labeler (if not supplied) and score GT vs predicted reports."""