parthsarin committed on
Commit 1e00313 · verified · 1 Parent(s): 16d2845

Initial upload: ModernBERT-base span head for funding statement extraction

Files changed (5)
  1. README.md +235 -0
  2. modeling.py +43 -0
  3. pytorch_model.bin +3 -0
  4. tokenizer.json +0 -0
  5. tokenizer_config.json +16 -0
README.md ADDED
@@ -0,0 +1,235 @@
+ ---
+ license: cc0-1.0
+ base_model: answerdotai/ModernBERT-base
+ library_name: transformers
+ pipeline_tag: token-classification
+ tags:
+ - funding-extraction
+ - arxiv
+ - scholarly-communication
+ - span-extraction
+ - modernbert
+ language:
+ - en
+ datasets:
+ - cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test
+ ---
+
+ # ModernBERT-base Span-Head: Funding Statement Extraction
+
+ A custom span-extraction head on top of `answerdotai/ModernBERT-base`. Given a
+ chunk of an academic paper (up to 8,192 tokens), it predicts the start and end
+ token positions of a funding statement, plus a "no-answer" probability for
+ chunks with no funding statement.
+
+ This is the **rough-extraction stage** of a two-stage cascade:
+
+ 1. **Stage 1 (this model)**: ModernBERT-base + span head, which finds the rough
+    span (loose-span F1 ≈ 0.96 at the 0.85 fuzzy-match threshold on the test set).
+ 2. **Stage 2 (separate)**: `cometadata/funding-cleaning-qwen3-4b-lora`, which
+    cleans the rough span into the canonical, normalized funding statement
+    (strips LaTeX markers, joins paragraph breaks, etc.).
+
+ Use this model alone if you only need approximate localization; chain it with
+ the cleanup LoRA if you need the cleaned canonical text.
+
+ ## Architecture
+
+ The architecture is a custom `SpanHead` module (included in `modeling.py`):
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel
+
+
+ class SpanHead(nn.Module):
+     """ModernBERT encoder + start/end/no-answer heads."""
+
+     def __init__(self, base="answerdotai/ModernBERT-base"):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(base)
+         h = self.encoder.config.hidden_size  # 768
+         self.start_head = nn.Linear(h, 1)
+         self.end_head = nn.Linear(h, 1)
+         self.no_answer_head = nn.Linear(h, 1)
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, input_ids, attention_mask):
+         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         hidden = self.dropout(out.last_hidden_state)
+         start_logits = self.start_head(hidden).squeeze(-1)
+         end_logits = self.end_head(hidden).squeeze(-1)
+         # Mean-pool for no-answer
+         mask = attention_mask.unsqueeze(-1).float()
+         pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
+         no_answer = self.no_answer_head(pooled).squeeze(-1)
+         return start_logits, end_logits, no_answer
+ ```
+
+ ## Use
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoTokenizer
+ from modeling import SpanHead  # bundled in this repo
+
+ REPO = "cometadata/funding-extraction-modernbert-base-spanhead"
+ device = "cuda"
+
+ tokenizer = AutoTokenizer.from_pretrained(REPO)
+ model = SpanHead("answerdotai/ModernBERT-base").to(device)
+ state_dict = torch.load(
+     hf_hub_download(REPO, "pytorch_model.bin"),
+     map_location=device, weights_only=True,
+ )
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ # `chunk_text` should be a ≤8192-token chunk of the paper (e.g., the
+ # acknowledgments-containing region). For long papers, run the model on
+ # sliding 8192-tok windows (stride 4096) and pick the chunk with the lowest
+ # no-answer probability.
+
+ enc = tokenizer(chunk_text, return_offsets_mapping=True,
+                 add_special_tokens=False, truncation=True, max_length=8192)
+ ids = torch.tensor(enc["input_ids"]).unsqueeze(0).to(device)
+ attn = torch.ones_like(ids)
+
+ with torch.no_grad():
+     with torch.amp.autocast("cuda", dtype=torch.bfloat16):
+         start_logits, end_logits, no_answer = model(ids, attn)
+
+ start_logits = start_logits.squeeze(0).float().cpu()
+ end_logits = end_logits.squeeze(0).float().cpu()
+ no_answer_prob = torch.sigmoid(no_answer).item()
+
+ if no_answer_prob >= 0.5:
+     pred_span = ""  # this chunk has no funding statement
+ else:
+     start = int(start_logits.argmax())
+     # Constrain end to be after start and within ~300 tokens
+     end_window = end_logits[start:start + 300]
+     end = start + int(end_window.argmax())
+     offsets = enc["offset_mapping"]
+     char_s = offsets[start][0]
+     char_e = offsets[end][1]
+     pred_span = chunk_text[char_s:char_e].strip()
+ ```
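
The sliding-window advice in the comment above can be sketched as follows. This is a minimal illustration rather than the authors' code: `sliding_windows` and `pick_best_chunk` are hypothetical helper names, the windows are cut from a plain list of token ids, and the no-answer probabilities are assumed to come from running the model once per window.

```python
def sliding_windows(token_ids, max_tok=8192, stride=4096):
    """Return (start_index, window) pairs that cover token_ids with overlap."""
    if len(token_ids) <= max_tok:
        return [(0, token_ids)]
    windows = []
    start = 0
    while True:
        windows.append((start, token_ids[start:start + max_tok]))
        if start + max_tok >= len(token_ids):
            break  # this window already reaches the end of the document
        start += stride
    return windows


def pick_best_chunk(no_answer_probs):
    """Index of the window with the lowest no-answer probability."""
    return min(range(len(no_answer_probs)), key=no_answer_probs.__getitem__)
```

Keeping the start index of each window makes it possible to map a predicted token span back to a position in the full document.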
+
+ ## Training data
+
+ Built from the 2,384 training rows of
+ `cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test`.
+
+ For each positive doc (1,416 rows):
+ - Tokenize `vlm_markdown` with the ModernBERT tokenizer.
+ - Locate the gold funding statement in `vlm_markdown` via verbatim substring
+   match, or via `rapidfuzz.fuzz.partial_ratio_alignment` if there is no
+   verbatim match. Convert the char-span to a token-span.
+ - Pick the 8,192-token sliding window (stride 4,096) that fully contains the
+   gold span. If the doc is ≤ 8,192 tokens, use the whole doc as one chunk.
+ - Training labels: `start_tok` and `end_tok` indices within the chunk;
+   `no_answer = 0`.
+
+ For each negative doc (968 rows):
+ - Use the last 8,192-token chunk of the doc (since funding statements, when
+   they exist, are typically near the end).
+ - Training labels: `start_tok = end_tok = 0`; `no_answer = 1`.
+
+ About 5% of positive rows, those for which no fuzzy alignment scoring ≥ 0.7
+ could be found, are dropped. Final training set: ~3,300 chunks.
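
The char-span to token-span conversion mentioned above can be sketched with the tokenizer's offset mapping. This is an assumption about how such a conversion might look, not the training script itself: `char_span_to_token_span` is a hypothetical helper, and `offsets` is taken to be a list of `(char_start, char_end)` pairs, one per token, as returned with `return_offsets_mapping=True`.

```python
def char_span_to_token_span(offsets, char_start, char_end):
    """Map the half-open char range [char_start, char_end) to an inclusive
    (start_tok, end_tok) pair, or None if no token overlaps the range."""
    start_tok = end_tok = None
    for i, (s, e) in enumerate(offsets):
        if s == e:
            continue  # skip zero-width entries such as special tokens
        if e > char_start and s < char_end:  # token overlaps the char range
            if start_tok is None:
                start_tok = i
            end_tok = i
    if start_tok is None:
        return None
    return start_tok, end_tok
```

For the verbatim case, `char_start` would be `text.find(gold)` and `char_end` would be `char_start + len(gold)`; the fuzzy case would take both values from the alignment result instead.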
+
+ ## Loss
+
+ ```
+ loss = CE(start_logits[no_answer==0], gold_start)
+      + CE(end_logits[no_answer==0], gold_end)
+      + 1.0 * BCE_with_logits(no_answer_logit, no_answer_label)
+ ```
+
+ The start/end CE is masked out on negative chunks; the no-answer BCE is
+ computed on all chunks. Padded positions in `start_logits`/`end_logits` are
+ masked to `-1e4` so they can't be argmax'd.
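
Written out as a formula, the objective above might look like the following. The symbols are introduced here for illustration only: $P$ is the set of positive chunks in a batch of $N$ chunks, $s_i$ and $e_i$ are the start and end logit vectors of chunk $i$ (with padded positions already set to $-10^4$), $a_i$ and $b_i$ are the gold start and end token indices, $z_i$ is the no-answer logit, and $y_i \in \{0, 1\}$ is the no-answer label. Mean reduction over chunks is an assumption; the source does not state the reduction.

```latex
\mathcal{L}
  = \frac{1}{|P|} \sum_{i \in P}
      \Bigl[ -\log \operatorname{softmax}(s_i)_{a_i}
             -\log \operatorname{softmax}(e_i)_{b_i} \Bigr]
  + \lambda \cdot \frac{1}{N} \sum_{i=1}^{N}
      \Bigl[ -y_i \log \sigma(z_i) - (1 - y_i) \log\bigl(1 - \sigma(z_i)\bigr) \Bigr],
  \qquad \lambda = 1.0
```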
+
+ ## Hyperparameters
+
+ - Base: `answerdotai/ModernBERT-base` (149M, 8,192-token context)
+ - Optimizer: AdamW, lr 5e-5, weight decay 0.01
+ - Schedule: linear warmup (30 steps) + cosine decay
+ - Epochs: 4
+ - Batch: 4 per device × 4 grad accum = 16 effective
+ - Mixed precision: bfloat16
+ - Max sequence: 8,192 tokens
+ - Trained on 1× H100 80GB
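
The warmup-plus-cosine schedule in the list above can be sketched as a plain function of the optimizer step. This is an illustrative assumption rather than the training code: `lr_at_step` is a hypothetical name, `total_steps` is not given in the source and must be supplied by the caller, and decaying all the way to zero is assumed.

```python
import math


def lr_at_step(step, total_steps, peak_lr=5e-5, warmup_steps=30):
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Linear ramp from 0 up to peak_lr over the warmup steps.
        return peak_lr * step / warmup_steps
    # Fraction of the post-warmup phase completed, clipped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    # Half-cosine from peak_lr down to 0 (assumed floor).
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```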
+
+ ## Evaluation
+
+ Evaluated on the 597-row test split of
+ `cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test`.
+ At inference we ran this model on the top-2 chunks selected by a separate
+ ModernBERT-base chunk classifier (binary funding-yes, mean-pooled
+ classification head) and picked the chunk with the lower no-answer probability.
+
+ | Metric                                | Precision | Recall | F1     | F0.5   |
+ |---------------------------------------|-----------|--------|--------|--------|
+ | Binary detection                      | 0.9887    | 0.9510 | 0.9694 | 0.9809 |
+ | Strict span (`token_sort_ratio≥0.95`) | 0.7365    | 0.7084 | 0.7222 | 0.7307 |
+ | Loose span (max-of-4 fuzz ≥ 0.85)     | 0.9745    | 0.9373 | 0.9556 | 0.9668 |
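
For readers who want to check the F1 and F0.5 columns, both follow from precision and recall through the standard F-beta formula. The helper below is a small sketch; `f_beta` is a name chosen here, not something taken from the evaluation code.

```python
def f_beta(precision, recall, beta=1.0):
    """F-beta score; beta < 1 weights precision more heavily than recall."""
    if precision == 0.0 and recall == 0.0:
        return 0.0
    b2 = beta * beta
    return (1.0 + b2) * precision * recall / (b2 * precision + recall)
```

Calling `f_beta(0.9887, 0.9510)` and `f_beta(0.9887, 0.9510, beta=0.5)` agrees with the binary-detection row up to rounding of the reported precision and recall.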
+
+ **Hard ceiling note**: ~28% of test gold statements are not verbatim
+ substrings of any source representation in the dataset (the dataset's labels
+ were normalized by frontier models: whitespace, LaTeX markers, paragraph
+ joins). The 0.95 strict threshold is unforgiving of those normalizations even
+ on perfectly extracted source spans, so strict F1 is capped near 0.73 for any
+ single-stage extractive model. The loose-span F1 of 0.96 is closer to the
+ practical extractive ceiling.
+
+ For higher strict F1, chain with `cometadata/funding-cleaning-qwen3-4b-lora`,
+ which cleans the rough span into the canonical text.
+
+ ## Cascade pipeline
+
+ For long papers (> 8,192 tokens), use a chunk classifier first to pick the
+ chunk most likely to contain the funding statement:
+
+ ```python
+ # Pseudocode for the full cascade
+ chunks = sliding_windows(doc, max_tok=8192, stride=4096)
+ chunk_probs = [chunk_classifier(c) for c in chunks]
+ top_chunk = chunks[argmax(chunk_probs)]
+ rough_span = spanhead_model(top_chunk)            # this model
+ clean_span = cleanup_lora(rough_span, top_chunk)  # other model
+ ```
+
+ A simple heuristic alternative to the chunk classifier also works well: use
+ the last 8,192-token window of the document, since funding statements are
+ usually near the end. This loses a few percentage points of recall on papers
+ with funding info mid-document.
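
The last-window heuristic is short enough to sketch directly. `last_window` is a hypothetical helper that works on a list of token ids; it also returns the window's start index so that a predicted span can be mapped back to the full document.

```python
def last_window(token_ids, max_tok=8192):
    """Return (start_index, window) for the final max_tok tokens of a document."""
    start = max(0, len(token_ids) - max_tok)
    return start, token_ids[start:]
```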
+
+ ## Intended use
+
+ Extraction of the **rough span** containing a funding acknowledgment from
+ arXiv paper text (or similar academic markdown). Designed to be the first
+ stage of a two-stage cascade with the cleanup LoRA, but usable on its own if
+ you only need approximate localization.
+
+ Not intended for: classification of funding sources, downstream
+ funder/grant/scheme parsing, or extraction from non-paper text.
+
+ ## Limitations
+
+ - Trained on arXiv-derived PDFs only; behavior on other paper sources is
+   untested.
+ - Outputs a rough span; for canonical, downstream-ready text, chain with the
+   cleanup LoRA.
+ - Will occasionally pick the wrong sibling sentence when an acknowledgments
+   section contains multiple funding statements (each person's own grants);
+   this is the dominant failure mode of the strict-F1 evaluation.
+
+ ## Citation / acknowledgement
+
+ Trained as part of an applied research cycle on the
+ `cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test`
+ dataset by Comet.
modeling.py ADDED
@@ -0,0 +1,43 @@
+ """Custom model class for funding-extraction-modernbert-base-spanhead.
+
+ Usage:
+     import torch
+     from huggingface_hub import hf_hub_download
+     from transformers import AutoTokenizer
+     from modeling import SpanHead
+
+     REPO = "cometadata/funding-extraction-modernbert-base-spanhead"
+     tokenizer = AutoTokenizer.from_pretrained(REPO)
+     model = SpanHead().to("cuda")
+     sd = torch.load(hf_hub_download(REPO, "pytorch_model.bin"),
+                     map_location="cuda", weights_only=True)
+     model.load_state_dict(sd)
+     model.eval()
+ """
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel
+
+
+ class SpanHead(nn.Module):
+     """ModernBERT-base encoder + start/end/no-answer heads for funding span extraction."""
+
+     def __init__(self, base: str = "answerdotai/ModernBERT-base"):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(base)
+         h = self.encoder.config.hidden_size  # 768
+         self.start_head = nn.Linear(h, 1)
+         self.end_head = nn.Linear(h, 1)
+         self.no_answer_head = nn.Linear(h, 1)
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, input_ids, attention_mask):
+         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         hidden = self.dropout(out.last_hidden_state)
+         start_logits = self.start_head(hidden).squeeze(-1)
+         end_logits = self.end_head(hidden).squeeze(-1)
+         # Mean-pool for no-answer
+         mask = attention_mask.unsqueeze(-1).float()
+         pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
+         no_answer = self.no_answer_head(pooled).squeeze(-1)
+         return start_logits, end_logits, no_answer
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a5f09370d87bf87db1fedb3502a17327b6eca1f6d34fc75b2187be1dde37bc0
+ size 596127249
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "backend": "tokenizers",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "is_local": false,
+   "mask_token": "[MASK]",
+   "model_input_names": [
+     "input_ids",
+     "attention_mask"
+   ],
+   "model_max_length": 8192,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "tokenizer_class": "TokenizersBackend",
+   "unk_token": "[UNK]"
+ }