Somasundaram Ayyappan Claude Opus 4.6 (1M context) committed on
Commit
f10912e
·
1 Parent(s): 0d6cd8e

Add section-aware chunked inference for resumes exceeding 512 tokens

Browse files

Splits text at paragraph boundaries (double newlines) and packs sections
into chunks that fit within the model's context window. Character offsets
are mapped back to the original text. Falls back to single-pass for
short inputs. Benchmark uses chunked inference — no regression on val set.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. training/benchmark_structured.py +78 -7
training/benchmark_structured.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import argparse
4
  import json
 
5
  from collections import Counter, defaultdict
6
 
7
  import torch
@@ -43,6 +44,82 @@ def predicted_spans_from_text(text: str, offset_mapping: list[tuple[int, int]],
43
  return text, spans
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def normalize_value(field: str, value: str | None) -> str | None:
47
  if not value:
48
  return None
@@ -153,13 +230,7 @@ def main() -> None:
153
  bucket = str(bucket_info["bucket"])
154
  bucket_totals[bucket]["examples"] += 1
155
 
156
- tokenized = tokenizer(gold_text, return_tensors="pt", return_offsets_mapping=True, truncation=True, max_length=512)
157
- encoded = {k: v for k, v in tokenized.items() if k in ALLOWED_INPUTS}
158
- with torch.no_grad():
159
- pred_ids = model(**encoded).logits.argmax(dim=-1).squeeze(0).cpu().tolist()
160
-
161
- offsets = [tuple(pair) for pair in tokenized["offset_mapping"].squeeze(0).cpu().tolist()][1:-1]
162
- pred_text, pred_spans = predicted_spans_from_text(gold_text, offsets, pred_ids[1:-1])
163
  pred_structured = postprocessor.build_structured_resume_from_spans(pred_spans, pred_text)
164
 
165
  gold_flat = flatten_resume(gold_structured)
 
2
 
3
  import argparse
4
  import json
5
+ import re
6
  from collections import Counter, defaultdict
7
 
8
  import torch
 
44
  return text, spans
45
 
46
 
47
+ def _split_into_sections(text: str) -> list[str]:
48
+ """Split resume text at double-newline boundaries into paragraph blocks."""
49
+ return [block for block in re.split(r"\n{2,}", text) if block.strip()]
50
+
51
+
52
def _section_bounds(text: str) -> list[tuple[int, int]]:
    """Return ``(start, end)`` character bounds of each non-blank paragraph.

    Paragraphs are delimited by runs of two or more newlines. Bounds index
    directly into *text*, so ``text[start:end]`` is the paragraph itself.
    """
    bounds: list[tuple[int, int]] = []
    cursor = 0
    for separator in re.finditer(r"\n{2,}", text):
        if text[cursor:separator.start()].strip():
            bounds.append((cursor, separator.start()))
        cursor = separator.end()
    if text[cursor:].strip():
        bounds.append((cursor, len(text)))
    return bounds


def _predict_spans_single_pass(text: str, model, tokenizer, max_length: int) -> tuple[str, list]:
    """Tokenize *text* (truncated to *max_length*), run the model, decode spans."""
    tokenized = tokenizer(
        text,
        return_tensors="pt",
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_length,
    )
    encoded = {k: v for k, v in tokenized.items() if k in ALLOWED_INPUTS}
    with torch.no_grad():
        pred_ids = model(**encoded).logits.argmax(dim=-1).squeeze(0).cpu().tolist()
    # Drop the first and last positions from both offsets and predictions
    # (presumably the tokenizer's special tokens -- confirm for the tokenizer in use).
    offsets = [tuple(pair) for pair in tokenized["offset_mapping"].squeeze(0).cpu().tolist()][1:-1]
    return predicted_spans_from_text(text, offsets, pred_ids[1:-1])


def chunked_predicted_spans(
    text: str,
    model,
    tokenizer,
    max_length: int = 512,
) -> tuple[str, list]:
    """Run inference with section-aware chunking for texts exceeding max_length.

    If the whole text tokenizes to at most ``max_length`` tokens it is
    processed in a single pass. Otherwise the text is split at paragraph
    boundaries (runs of two or more newlines) and consecutive paragraphs are
    greedily packed into chunks whose token count stays within ``max_length``.

    Every chunk is an exact slice of ``text`` (original separators included),
    so span offsets reported for a chunk are converted to offsets in the
    original text by adding the chunk's start position.

    Args:
        text: Full input text.
        model: Model called as ``model(**encoded)``; its ``.logits`` are
            argmax-decoded along the last dimension.
        tokenizer: Tokenizer callable supporting ``return_offsets_mapping``.
        max_length: Maximum number of tokens per forward pass.

    Returns:
        ``(text, spans)``. On the chunked path, ``spans`` are rebuilt ``Span``
        objects whose ``start``/``end`` refer to positions in ``text``.

    Note:
        A single paragraph that alone exceeds ``max_length`` is still sent as
        its own chunk and truncated by the tokenizer, so its tail receives no
        predictions.
    """
    num_tokens = len(tokenizer(text, truncation=False)["input_ids"])
    if num_tokens <= max_length:
        return _predict_spans_single_pass(text, model, tokenizer, max_length)

    # Imported once, and only on the chunked path, rather than once per span.
    from training.structured_postprocess import Span

    # Greedily pack consecutive paragraphs into (start, end) slices of `text`.
    chunk_bounds: list[tuple[int, int]] = []
    chunk_start: int | None = None
    chunk_end = 0
    for section_start, section_end in _section_bounds(text):
        if chunk_start is None:
            # The first paragraph of a chunk is always accepted, even if oversized.
            chunk_start, chunk_end = section_start, section_end
            continue
        candidate_len = len(tokenizer(text[chunk_start:section_end], truncation=False)["input_ids"])
        if candidate_len > max_length:
            # Adding this paragraph would overflow: close the chunk, start a new one.
            chunk_bounds.append((chunk_start, chunk_end))
            chunk_start = section_start
        chunk_end = section_end
    if chunk_start is not None:
        chunk_bounds.append((chunk_start, chunk_end))

    all_spans = []
    for start, end in chunk_bounds:
        _, spans = _predict_spans_single_pass(text[start:end], model, tokenizer, max_length)
        # The chunk is a verbatim slice of `text`, so shifting by its start
        # position yields exact offsets in the original text.
        all_spans.extend(
            Span(
                label=span.label,
                text=span.text,
                start=span.start + start,
                end=span.end + start,
                bio=span.bio,
                score=span.score,
            )
            for span in spans
        )

    return text, all_spans
121
+
122
+
123
  def normalize_value(field: str, value: str | None) -> str | None:
124
  if not value:
125
  return None
 
230
  bucket = str(bucket_info["bucket"])
231
  bucket_totals[bucket]["examples"] += 1
232
 
233
+ pred_text, pred_spans = chunked_predicted_spans(gold_text, model, tokenizer)
 
 
 
 
 
 
234
  pred_structured = postprocessor.build_structured_resume_from_spans(pred_spans, pred_text)
235
 
236
  gold_flat = flatten_resume(gold_structured)