feat: add compression classifier score calibration
Browse files
README.md
CHANGED
|
@@ -48,8 +48,9 @@ Live Space:
|
|
| 48 |
- Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
|
| 49 |
set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
|
| 50 |
labels through ONNX Runtime or the manifest fallback. ONNX labels are
|
| 51 |
-
evaluated in chunked windows using manifest `max_length` and `stride`;
|
| 52 |
-
|
|
|
|
| 53 |
|
| 54 |
Deploy:
|
| 55 |
|
|
|
|
| 48 |
- Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
|
| 49 |
set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
|
| 50 |
labels through ONNX Runtime or the manifest fallback. ONNX labels are
|
| 51 |
+
evaluated in chunked windows using manifest `max_length` and `stride`; mounted
|
| 52 |
+
ONNX labels expose `keep_score`, `drop_score`, and `drop_score_threshold`.
|
| 53 |
+
DROP spans still pass through protected-span and deletion-only safety gates.
|
| 54 |
|
| 55 |
Deploy:
|
| 56 |
|
app.py
CHANGED
|
@@ -207,29 +207,37 @@ def _onnx_labels(
|
|
| 207 |
str(key): value
|
| 208 |
for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
|
| 209 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
best_by_span: dict[tuple[int, int], dict[str, Any]] = {}
|
| 211 |
for chunk_index, chunk_logits in enumerate(logits):
|
| 212 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index])
|
| 213 |
for token_index, token_logits in enumerate(chunk_logits):
|
| 214 |
probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
|
| 215 |
-
|
|
|
|
|
|
|
| 216 |
start, end = offsets[chunk_index][token_index]
|
| 217 |
start = int(start)
|
| 218 |
end = int(end)
|
| 219 |
if end <= start:
|
| 220 |
continue
|
| 221 |
-
score = round(float(probs[label_id]), 6)
|
| 222 |
key = (start, end)
|
| 223 |
item = {
|
| 224 |
"token": tokens[token_index],
|
| 225 |
"label": id2label.get(str(label_id), "KEEP"),
|
| 226 |
-
"score":
|
|
|
|
|
|
|
|
|
|
| 227 |
"start": start,
|
| 228 |
"end": end,
|
| 229 |
"source": "onnx_token_classifier",
|
| 230 |
"chunk_index": chunk_index,
|
| 231 |
}
|
| 232 |
-
if
|
| 233 |
best_by_span[key] = item
|
| 234 |
return [best_by_span[key] for key in sorted(best_by_span)]
|
| 235 |
|
|
@@ -245,13 +253,21 @@ def _safe_classifier_drop_ranges(
|
|
| 245 |
drop_labels = 0
|
| 246 |
blocked = 0
|
| 247 |
for item in labels:
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
continue
|
| 250 |
drop_labels += 1
|
| 251 |
try:
|
| 252 |
start = int(item["start"])
|
| 253 |
end = int(item["end"])
|
| 254 |
-
score = float(item.get("score", 1.0))
|
| 255 |
except Exception:
|
| 256 |
blocked += 1
|
| 257 |
continue
|
|
|
|
| 207 |
str(key): value
|
| 208 |
for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
|
| 209 |
}
|
| 210 |
+
label2id = {str(value).upper(): int(key) for key, value in id2label.items()}
|
| 211 |
+
keep_id = label2id.get("KEEP", 0)
|
| 212 |
+
drop_id = label2id.get("DROP", 1)
|
| 213 |
+
drop_score_threshold = float(manifest.get("drop_score_threshold", 0.5))
|
| 214 |
best_by_span: dict[tuple[int, int], dict[str, Any]] = {}
|
| 215 |
for chunk_index, chunk_logits in enumerate(logits):
|
| 216 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index])
|
| 217 |
for token_index, token_logits in enumerate(chunk_logits):
|
| 218 |
probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
|
| 219 |
+
keep_score = float(probs[keep_id]) if keep_id < len(probs) else 0.0
|
| 220 |
+
drop_score = float(probs[drop_id]) if drop_id < len(probs) else 0.0
|
| 221 |
+
label_id = drop_id if drop_score >= drop_score_threshold else keep_id
|
| 222 |
start, end = offsets[chunk_index][token_index]
|
| 223 |
start = int(start)
|
| 224 |
end = int(end)
|
| 225 |
if end <= start:
|
| 226 |
continue
|
|
|
|
| 227 |
key = (start, end)
|
| 228 |
item = {
|
| 229 |
"token": tokens[token_index],
|
| 230 |
"label": id2label.get(str(label_id), "KEEP"),
|
| 231 |
+
"score": round(drop_score if label_id == drop_id else keep_score, 6),
|
| 232 |
+
"keep_score": round(keep_score, 6),
|
| 233 |
+
"drop_score": round(drop_score, 6),
|
| 234 |
+
"drop_score_threshold": drop_score_threshold,
|
| 235 |
"start": start,
|
| 236 |
"end": end,
|
| 237 |
"source": "onnx_token_classifier",
|
| 238 |
"chunk_index": chunk_index,
|
| 239 |
}
|
| 240 |
+
if drop_score > float(best_by_span.get(key, {}).get("drop_score", -1.0)):
|
| 241 |
best_by_span[key] = item
|
| 242 |
return [best_by_span[key] for key in sorted(best_by_span)]
|
| 243 |
|
|
|
|
| 253 |
drop_labels = 0
|
| 254 |
blocked = 0
|
| 255 |
for item in labels:
|
| 256 |
+
raw_drop_score = item.get("drop_score")
|
| 257 |
+
try:
|
| 258 |
+
drop_score = float(raw_drop_score) if raw_drop_score is not None else None
|
| 259 |
+
except (TypeError, ValueError):
|
| 260 |
+
drop_score = None
|
| 261 |
+
if (
|
| 262 |
+
str(item.get("label") or item.get("entity") or "").upper() != "DROP"
|
| 263 |
+
and (drop_score is None or drop_score < min_score)
|
| 264 |
+
):
|
| 265 |
continue
|
| 266 |
drop_labels += 1
|
| 267 |
try:
|
| 268 |
start = int(item["start"])
|
| 269 |
end = int(item["end"])
|
| 270 |
+
score = drop_score if drop_score is not None else float(item.get("score", 1.0))
|
| 271 |
except Exception:
|
| 272 |
blocked += 1
|
| 273 |
continue
|