wchen22 committed on
Commit
b5632b3
·
verified ·
1 Parent(s): e4bdbeb

feat: chunk classifier artifact windows

Browse files
Files changed (2) hide show
  1. README.md +3 -2
  2. app.py +31 -19
README.md CHANGED
@@ -44,8 +44,9 @@ Live Space:
44
  managed `inputs[]` batches with per-item receipts and partial-error rows.
45
  - Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
46
  set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
47
- labels through ONNX Runtime or the manifest fallback. Those labels still pass
48
- through protected-span and deletion-only safety gates.
 
49
 
50
  Deploy:
51
 
 
44
  managed `inputs[]` batches with per-item receipts and partial-error rows.
45
  - Mount `classifier_manifest.json`, tokenizer files, and optional `model.onnx`;
46
  set `TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR` to let the Space use artifact DROP
47
+ labels through ONNX Runtime or the manifest fallback. ONNX labels are
48
+ evaluated in chunked windows using manifest `max_length` and `stride`; labels
49
+ still pass through protected-span and deletion-only safety gates.
50
 
51
  Deploy:
52
 
app.py CHANGED
@@ -193,33 +193,45 @@ def _onnx_labels(
193
  return_tensors="np",
194
  truncation=True,
195
  max_length=int(manifest.get("max_length", 512)),
 
 
 
196
  )
197
- offsets = encoded.pop("offset_mapping")[0]
 
198
  session = _get_onnx_session(str(model_path))
199
  input_names = {item.name for item in session.get_inputs()}
200
  inputs = {key: value for key, value in encoded.items() if key in input_names}
201
- logits = session.run(None, inputs)[0][0]
202
  id2label = {
203
  str(key): value
204
  for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
205
  }
206
- tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
207
- labels = []
208
- for index, token_logits in enumerate(logits):
209
- probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
210
- label_id = int(np.argmax(probs))
211
- start, end = offsets[index]
212
- if int(end) <= int(start):
213
- continue
214
- labels.append({
215
- "token": tokens[index],
216
- "label": id2label.get(str(label_id), "KEEP"),
217
- "score": round(float(probs[label_id]), 6),
218
- "start": int(start),
219
- "end": int(end),
220
- "source": "onnx_token_classifier",
221
- })
222
- return labels
 
 
 
 
 
 
 
 
223
 
224
 
225
  def _safe_classifier_drop_ranges(
 
193
  return_tensors="np",
194
  truncation=True,
195
  max_length=int(manifest.get("max_length", 512)),
196
+ stride=int(manifest.get("stride", 0)),
197
+ return_overflowing_tokens=True,
198
+ padding=True,
199
  )
200
+ offsets = encoded.pop("offset_mapping")
201
+ input_ids = encoded["input_ids"]
202
  session = _get_onnx_session(str(model_path))
203
  input_names = {item.name for item in session.get_inputs()}
204
  inputs = {key: value for key, value in encoded.items() if key in input_names}
205
+ logits = session.run(None, inputs)[0]
206
  id2label = {
207
  str(key): value
208
  for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
209
  }
210
+ best_by_span: dict[tuple[int, int], dict[str, Any]] = {}
211
+ for chunk_index, chunk_logits in enumerate(logits):
212
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index])
213
+ for token_index, token_logits in enumerate(chunk_logits):
214
+ probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
215
+ label_id = int(np.argmax(probs))
216
+ start, end = offsets[chunk_index][token_index]
217
+ start = int(start)
218
+ end = int(end)
219
+ if end <= start:
220
+ continue
221
+ score = round(float(probs[label_id]), 6)
222
+ key = (start, end)
223
+ item = {
224
+ "token": tokens[token_index],
225
+ "label": id2label.get(str(label_id), "KEEP"),
226
+ "score": score,
227
+ "start": start,
228
+ "end": end,
229
+ "source": "onnx_token_classifier",
230
+ "chunk_index": chunk_index,
231
+ }
232
+ if score > float(best_by_span.get(key, {}).get("score", -1.0)):
233
+ best_by_span[key] = item
234
+ return [best_by_span[key] for key in sorted(best_by_span)]
235
 
236
 
237
  def _safe_classifier_drop_ranges(