feat: chunk classifier artifact windows
Browse files
README.md
CHANGED
|
@@ -44,8 +44,9 @@ Live Space:
|
|
| 44 |
managed `inputs[]` batches with per-item receipts and partial-error rows.
|
| 45 |
- Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
|
| 46 |
set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
|
| 47 |
-
labels through ONNX Runtime or the manifest fallback.
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
Deploy:
|
| 51 |
|
|
|
|
| 44 |
managed `inputs[]` batches with per-item receipts and partial-error rows.
|
| 45 |
- Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
|
| 46 |
set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
|
| 47 |
+
labels through ONNX Runtime or the manifest fallback. ONNX labels are
|
| 48 |
+
evaluated in chunked windows using manifest `max_length` and `stride`; labels
|
| 49 |
+
still pass through protected-span and deletion-only safety gates.
|
| 50 |
|
| 51 |
Deploy:
|
| 52 |
|
app.py
CHANGED
|
@@ -193,33 +193,45 @@ def _onnx_labels(
|
|
| 193 |
return_tensors="np",
|
| 194 |
truncation=True,
|
| 195 |
max_length=int(manifest.get("max_length", 512)),
|
|
|
|
|
|
|
|
|
|
| 196 |
)
|
| 197 |
-
offsets = encoded.pop("offset_mapping")
|
|
|
|
| 198 |
session = _get_onnx_session(str(model_path))
|
| 199 |
input_names = {item.name for item in session.get_inputs()}
|
| 200 |
inputs = {key: value for key, value in encoded.items() if key in input_names}
|
| 201 |
-
logits = session.run(None, inputs)[0]
|
| 202 |
id2label = {
|
| 203 |
str(key): value
|
| 204 |
for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
|
| 205 |
}
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
|
| 225 |
def _safe_classifier_drop_ranges(
|
|
|
|
| 193 |
return_tensors="np",
|
| 194 |
truncation=True,
|
| 195 |
max_length=int(manifest.get("max_length", 512)),
|
| 196 |
+
stride=int(manifest.get("stride", 0)),
|
| 197 |
+
return_overflowing_tokens=True,
|
| 198 |
+
padding=True,
|
| 199 |
)
|
| 200 |
+
offsets = encoded.pop("offset_mapping")
|
| 201 |
+
input_ids = encoded["input_ids"]
|
| 202 |
session = _get_onnx_session(str(model_path))
|
| 203 |
input_names = {item.name for item in session.get_inputs()}
|
| 204 |
inputs = {key: value for key, value in encoded.items() if key in input_names}
|
| 205 |
+
logits = session.run(None, inputs)[0]
|
| 206 |
id2label = {
|
| 207 |
str(key): value
|
| 208 |
for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
|
| 209 |
}
|
| 210 |
+
best_by_span: dict[tuple[int, int], dict[str, Any]] = {}
|
| 211 |
+
for chunk_index, chunk_logits in enumerate(logits):
|
| 212 |
+
tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index])
|
| 213 |
+
for token_index, token_logits in enumerate(chunk_logits):
|
| 214 |
+
probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
|
| 215 |
+
label_id = int(np.argmax(probs))
|
| 216 |
+
start, end = offsets[chunk_index][token_index]
|
| 217 |
+
start = int(start)
|
| 218 |
+
end = int(end)
|
| 219 |
+
if end <= start:
|
| 220 |
+
continue
|
| 221 |
+
score = round(float(probs[label_id]), 6)
|
| 222 |
+
key = (start, end)
|
| 223 |
+
item = {
|
| 224 |
+
"token": tokens[token_index],
|
| 225 |
+
"label": id2label.get(str(label_id), "KEEP"),
|
| 226 |
+
"score": score,
|
| 227 |
+
"start": start,
|
| 228 |
+
"end": end,
|
| 229 |
+
"source": "onnx_token_classifier",
|
| 230 |
+
"chunk_index": chunk_index,
|
| 231 |
+
}
|
| 232 |
+
if score > float(best_by_span.get(key, {}).get("score", -1.0)):
|
| 233 |
+
best_by_span[key] = item
|
| 234 |
+
return [best_by_span[key] for key in sorted(best_by_span)]
|
| 235 |
|
| 236 |
|
| 237 |
def _safe_classifier_drop_ranges(
|