wchen22 committed on
Commit
472c58c
·
verified ·
1 Parent(s): 87c9bb7

feat: add compression classifier score calibration

Browse files
Files changed (2) hide show
  1. README.md +3 -2
  2. app.py +22 -6
README.md CHANGED
@@ -48,8 +48,9 @@ Live Space:
48
  - Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
49
  set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
50
  labels through ONNX Runtime or the manifest fallback. ONNX labels are
51
- evaluated in chunked windows using manifest `max_length` and `stride`; labels
52
- still pass through protected-span and deletion-only safety gates.
 
53
 
54
  Deploy:
55
 
 
48
  - Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
49
  set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
50
  labels through ONNX Runtime or the manifest fallback. ONNX labels are
51
+ evaluated in chunked windows using manifest `max_length` and `stride`; mounted
52
+ ONNX labels expose `keep_score`, `drop_score`, and `drop_score_threshold`.
53
+ DROP spans still pass through protected-span and deletion-only safety gates.
54
 
55
  Deploy:
56
 
app.py CHANGED
@@ -207,29 +207,37 @@ def _onnx_labels(
207
  str(key): value
208
  for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
209
  }
 
 
 
 
210
  best_by_span: dict[tuple[int, int], dict[str, Any]] = {}
211
  for chunk_index, chunk_logits in enumerate(logits):
212
  tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index])
213
  for token_index, token_logits in enumerate(chunk_logits):
214
  probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
215
- label_id = int(np.argmax(probs))
 
 
216
  start, end = offsets[chunk_index][token_index]
217
  start = int(start)
218
  end = int(end)
219
  if end <= start:
220
  continue
221
- score = round(float(probs[label_id]), 6)
222
  key = (start, end)
223
  item = {
224
  "token": tokens[token_index],
225
  "label": id2label.get(str(label_id), "KEEP"),
226
- "score": score,
 
 
 
227
  "start": start,
228
  "end": end,
229
  "source": "onnx_token_classifier",
230
  "chunk_index": chunk_index,
231
  }
232
- if score > float(best_by_span.get(key, {}).get("score", -1.0)):
233
  best_by_span[key] = item
234
  return [best_by_span[key] for key in sorted(best_by_span)]
235
 
@@ -245,13 +253,21 @@ def _safe_classifier_drop_ranges(
245
  drop_labels = 0
246
  blocked = 0
247
  for item in labels:
248
- if str(item.get("label") or item.get("entity") or "").upper() != "DROP":
 
 
 
 
 
 
 
 
249
  continue
250
  drop_labels += 1
251
  try:
252
  start = int(item["start"])
253
  end = int(item["end"])
254
- score = float(item.get("score", 1.0))
255
  except Exception:
256
  blocked += 1
257
  continue
 
207
  str(key): value
208
  for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
209
  }
210
+ label2id = {str(value).upper(): int(key) for key, value in id2label.items()}
211
+ keep_id = label2id.get("KEEP", 0)
212
+ drop_id = label2id.get("DROP", 1)
213
+ drop_score_threshold = float(manifest.get("drop_score_threshold", 0.5))
214
  best_by_span: dict[tuple[int, int], dict[str, Any]] = {}
215
  for chunk_index, chunk_logits in enumerate(logits):
216
  tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index])
217
  for token_index, token_logits in enumerate(chunk_logits):
218
  probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
219
+ keep_score = float(probs[keep_id]) if keep_id < len(probs) else 0.0
220
+ drop_score = float(probs[drop_id]) if drop_id < len(probs) else 0.0
221
+ label_id = drop_id if drop_score >= drop_score_threshold else keep_id
222
  start, end = offsets[chunk_index][token_index]
223
  start = int(start)
224
  end = int(end)
225
  if end <= start:
226
  continue
 
227
  key = (start, end)
228
  item = {
229
  "token": tokens[token_index],
230
  "label": id2label.get(str(label_id), "KEEP"),
231
+ "score": round(drop_score if label_id == drop_id else keep_score, 6),
232
+ "keep_score": round(keep_score, 6),
233
+ "drop_score": round(drop_score, 6),
234
+ "drop_score_threshold": drop_score_threshold,
235
  "start": start,
236
  "end": end,
237
  "source": "onnx_token_classifier",
238
  "chunk_index": chunk_index,
239
  }
240
+ if drop_score > float(best_by_span.get(key, {}).get("drop_score", -1.0)):
241
  best_by_span[key] = item
242
  return [best_by_span[key] for key in sorted(best_by_span)]
243
 
 
253
  drop_labels = 0
254
  blocked = 0
255
  for item in labels:
256
+ raw_drop_score = item.get("drop_score")
257
+ try:
258
+ drop_score = float(raw_drop_score) if raw_drop_score is not None else None
259
+ except (TypeError, ValueError):
260
+ drop_score = None
261
+ if (
262
+ str(item.get("label") or item.get("entity") or "").upper() != "DROP"
263
+ and (drop_score is None or drop_score < min_score)
264
+ ):
265
  continue
266
  drop_labels += 1
267
  try:
268
  start = int(item["start"])
269
  end = int(item["end"])
270
+ score = drop_score if drop_score is not None else float(item.get("score", 1.0))
271
  except Exception:
272
  blocked += 1
273
  continue