parthsarin committed on
Commit f69ad93 · verified · 1 Parent(s): 81614fc

Initial upload: ModernBERT-base chunk classifier (stage 1 of funding-extraction cascade)

Files changed (5)
  1. README.md +203 -0
  2. modeling.py +34 -0
  3. pytorch_model.bin +3 -0
  4. tokenizer.json +0 -0
  5. tokenizer_config.json +16 -0
README.md ADDED
@@ -0,0 +1,203 @@
+ ---
+ license: cc0-1.0
+ base_model: answerdotai/ModernBERT-base
+ library_name: transformers
+ pipeline_tag: text-classification
+ tags:
+ - funding-extraction
+ - arxiv
+ - scholarly-communication
+ - chunk-classification
+ - modernbert
+ language:
+ - en
+ datasets:
+ - cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test
+ ---
+
+ # ModernBERT-base Chunk Classifier: Funding Statement Localization
+
+ A binary classifier on top of `answerdotai/ModernBERT-base` that scores a
+ single 8,192-token chunk of an academic paper for the presence of a funding
+ statement. Used as **stage 1 of a three-stage funding-extraction cascade** to
+ narrow a long PDF down to its most likely chunks before running expensive
+ span extraction and cleanup.
+
+ The full cascade:
+
+ 1. **Stage 1 (this model)**: For each ≤8,192-token chunk of the paper,
+    predict a scalar `P(this chunk contains a funding statement)`. Keep the
+    top-K chunks above a threshold (we use top-2 above 0.4).
+ 2. **Stage 2 (span head)**:
+    [`cometadata/funding-extraction-modernbert-base-spanhead`](https://huggingface.co/cometadata/funding-extraction-modernbert-base-spanhead)
+    picks the exact start/end token within the top chunk.
+ 3. **Stage 3 (cleanup LoRA)**:
+    [`cometadata/funding-cleaning-qwen3-4b-lora`](https://huggingface.co/cometadata/funding-cleaning-qwen3-4b-lora)
+    strips LaTeX markers and normalizes whitespace in the extracted span.
+
+ You can use this model standalone if you only need to flag whether a chunk
+ (or doc) contains funding language (doc-level binary F1 0.97 on the test set).
+
+ ## Architecture
+
+ The architecture is a custom `ChunkClassifier` module (included in
+ `modeling.py`):
+
+ ```python
+ import torch.nn as nn
+ from transformers import AutoModel
+
+
+ class ChunkClassifier(nn.Module):
+     """ModernBERT encoder + mean-pool + binary head."""
+
+     def __init__(self, base="answerdotai/ModernBERT-base"):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(base)
+         self.head = nn.Linear(self.encoder.config.hidden_size, 1)
+
+     def forward(self, input_ids, attention_mask):
+         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         # Mean pool over real (non-padding) tokens
+         mask = attention_mask.unsqueeze(-1).float()
+         pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
+         return self.head(pooled).squeeze(-1)  # one logit per chunk
+ ```
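The masked mean pool in `forward` can be illustrated without PyTorch. The sketch below is a dependency-free toy version on plain lists, assuming one sequence of per-token vectors and a 0/1 mask; `masked_mean` is a hypothetical helper name and is not part of `modeling.py`.

```python
def masked_mean(hidden, mask):
    """Average per-token vectors, ignoring positions where mask is 0.

    hidden: list of equal-length vectors, one per token.
    mask:   list of 1 (real token) or 0 (padding), same length as hidden.
    """
    dim = len(hidden[0])
    n_real = max(1, sum(mask))  # plays the role of .clamp(min=1)
    return [
        sum(vec[d] * m for vec, m in zip(hidden, mask)) / n_real
        for d in range(dim)
    ]


# Two real tokens followed by one padding token with large values.
hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(masked_mean(hidden, mask))  # [2.0, 3.0]
```

The padding row does not affect the result, which is the point of multiplying by the mask before summing.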
+
+ ## Use
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoTokenizer
+ from modeling import ChunkClassifier  # bundled in this repo
+
+ REPO = "cometadata/funding-chunk-classifier-modernbert-base"
+ device = "cuda"  # this example assumes a CUDA GPU
+
+ tokenizer = AutoTokenizer.from_pretrained(REPO)
+ model = ChunkClassifier("answerdotai/ModernBERT-base").to(device)
+ state_dict = torch.load(
+     hf_hub_download(REPO, "pytorch_model.bin"),
+     map_location=device, weights_only=True,
+ )
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ # For a long paper, slide an 8192-token window with stride 4096.
+ def chunks_of(text, max_tok=8192, stride=4096):
+     enc = tokenizer(text, add_special_tokens=False, truncation=False)
+     ids = enc["input_ids"]
+     if len(ids) <= max_tok:
+         yield ids, 0, len(ids)
+         return
+     for st in range(0, len(ids), stride):
+         en = min(st + max_tok, len(ids))
+         yield ids[st:en], st, en
+         if en == len(ids):
+             break
+
+ # `paper_text` is the full text of one paper (a str) that you supply.
+ probs = []
+ for chunk_ids, st, en in chunks_of(paper_text):
+     ids_t = torch.tensor(chunk_ids).unsqueeze(0).to(device)
+     attn = torch.ones_like(ids_t)
+     with torch.no_grad():
+         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
+             logit = model(ids_t, attn).float()
+     probs.append((torch.sigmoid(logit).item(), st, en))
+
+ # Top-K chunks above threshold
+ top_k = sorted(probs, key=lambda p: -p[0])[:2]
+ top_k = [p for p in top_k if p[0] >= 0.4]
+ # `top_k` is the list to hand off to the span-head model.
+ ```
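To see which token ranges the sliding window in `chunks_of` produces, the span arithmetic can be separated from the tokenizer. The sketch below works on a token count only and mirrors the loop above; `window_spans` is a hypothetical helper name, not something shipped in this repo.

```python
def window_spans(n_tokens, max_tok=8192, stride=4096):
    """Return the (start, end) token offsets that chunks_of would yield."""
    if n_tokens <= max_tok:
        return [(0, n_tokens)]
    spans = []
    for st in range(0, n_tokens, stride):
        en = min(st + max_tok, n_tokens)
        spans.append((st, en))
        if en == n_tokens:  # the last window reached the end of the document
            break
    return spans


print(window_spans(20_000))
# [(0, 8192), (4096, 12288), (8192, 16384), (12288, 20000)]
```

With a stride of half the window, every token outside the first and last 4,096 positions is covered by two windows, so a funding statement that straddles one window boundary still falls entirely inside a neighboring window.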
+
+ ## Training data
+
+ Built from the 2,384 training rows of
+ `cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test`.
+
+ For each train doc:
+ - Tokenize `vlm_markdown` with the ModernBERT tokenizer.
+ - Slide an 8,192-token window with stride 4,096 over the tokenized doc.
+ - For each chunk, label `1` iff the gold funding statement (located via
+   verbatim substring or `rapidfuzz.fuzz.partial_ratio_alignment ≥ 0.7`)
+   overlaps the chunk's character range by more than half of the statement's
+   length, else `0`.
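The labeling rule in the last bullet can be sketched as an interval-overlap check. This is an illustrative reading rather than the training script: it assumes the gold statement has already been located as a character span, that each chunk's character range is known, and that "more than half" is measured against the statement's length. `chunk_label` and its parameters are hypothetical names.

```python
def chunk_label(gold_span, chunk_span):
    """Return 1 if more than half of the gold span lies inside the chunk, else 0.

    Both spans are (start, end) character offsets with an exclusive end.
    gold_span is None for documents that have no funding statement.
    """
    if gold_span is None:
        return 0
    g_start, g_end = gold_span
    c_start, c_end = chunk_span
    overlap = max(0, min(g_end, c_end) - max(g_start, c_start))
    gold_len = g_end - g_start
    return int(gold_len > 0 and overlap > 0.5 * gold_len)


# A 272-character statement starting at character offset 30,000.
gold = (30_000, 30_272)
chunks = [(0, 25_000), (12_000, 30_100), (24_000, 50_000)]
print([chunk_label(gold, c) for c in chunks])  # [0, 0, 1]
```

The second chunk contains only 100 of the statement's 272 characters, so it is labeled negative under this reading.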
+
+ Negative docs (no funding statement) contribute negative chunks; positive
+ docs contribute one positive chunk (the one containing the gold) plus several
+ negative chunks from the rest of the doc, so the negative class is
+ naturally dominant (~8× more negatives than positives).
+
+ Final training set: roughly 21,000 chunks (~2,300 positive / ~18,700
+ negative).
+
+ ## Loss
+
+ Binary cross-entropy with `pos_weight = n_examples / n_positives` to
+ counteract the class imbalance:
+
+ ```python
+ loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(n_examples / n_positives))
+ loss = loss_fn(logits, labels)
+ ```
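As a rough worked example, plugging the approximate chunk counts from the training-data section into this formula gives a positive-class weight of about 9. The sketch below only does that arithmetic and shows how the weight scales the loss on a single example; the counts are the rounded figures quoted above, and `weighted_bce` is a hypothetical pure-Python stand-in for `BCEWithLogitsLoss` on one logit.

```python
import math

n_examples, n_positives = 21_000, 2_300  # rounded counts from the README
pos_weight = n_examples / n_positives


def weighted_bce(logit, label, pos_weight):
    """Single-example BCE on a logit; the positive term is scaled by pos_weight."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(pos_weight * label * math.log(p) + (1 - label) * math.log(1 - p))


print(round(pos_weight, 2))  # 9.13
# At logit 0 the model is maximally unsure; a positive example then costs
# pos_weight times as much as a negative one.
ratio = weighted_bce(0.0, 1, pos_weight) / weighted_bce(0.0, 0, pos_weight)
print(round(ratio, 2))  # 9.13
```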
+
+ ## Hyperparameters
+
+ - Base: `answerdotai/ModernBERT-base` (149M parameters, 8,192-token context)
+ - Optimizer: AdamW, lr 5e-5, weight decay 0.01
+ - Schedule: linear warmup (20 steps) + cosine decay
+ - Epochs: 3
+ - Batch: 2 per device × 8 grad accum = 16 effective
+ - Mixed precision: bfloat16
+ - Max sequence: 8,192 tokens
+ - Trained on 1× H100 80GB
+ - Saved checkpoint: `pytorch_model.bin` is the state dict after epoch 2
+   (zero-indexed, i.e. the third and final epoch)
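The warmup-plus-cosine schedule listed above can be written as a function of the optimizer step. This is a sketch of the usual formulation (a linear ramp to the peak rate, then a half-cosine decay to zero), not the training script; `lr_at` is a hypothetical name and `total_steps` is an assumed value chosen for illustration, since the README does not state the step count.

```python
import math


def lr_at(step, peak_lr=5e-5, warmup_steps=20, total_steps=1_000):
    """Learning rate at `step`: linear warmup, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


# Start of warmup, mid-warmup, peak, halfway through decay, end of training.
for step in (0, 10, 20, 510, 1_000):
    print(step, lr_at(step))
```

With only 20 warmup steps the ramp is very short relative to a multi-epoch run, so almost all of training happens on the cosine part of the curve.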
+
+ ## Evaluation
+
+ On the 597-row test split of
+ `cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test`,
+ treated as a **per-document binary task** (does the doc have any funding
+ statement?): we score each candidate chunk and use the max probability as
+ the document-level prediction. Threshold = 0.5.
+
+ | Task                         | Precision | Recall | F1     | F0.5   |
+ |------------------------------|-----------|--------|--------|--------|
+ | Doc-level funding detection  | 0.9831    | 0.9537 | 0.9682 | 0.9771 |
+
+ Confusion counts at threshold 0.5: TP=350, FP=6, FN=17, TN=224.
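The document-level decision rule and the reported metrics can be checked from the counts above. The sketch below is a minimal illustration: `doc_prediction` applies the max-over-chunks rule described in this section, and `precision_recall_fbeta` recomputes the table from TP/FP/FN; both function names are hypothetical.

```python
def doc_prediction(chunk_probs, threshold=0.5):
    """A document is positive if its highest-scoring chunk reaches the threshold."""
    return max(chunk_probs, default=0.0) >= threshold


def precision_recall_fbeta(tp, fp, fn, beta=1.0):
    """Precision, recall, and F-beta from confusion counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    b2 = beta * beta
    f_beta = (1 + b2) * precision * recall / (b2 * precision + recall)
    return precision, recall, f_beta


tp, fp, fn, tn = 350, 6, 17, 224
precision, recall, f1 = precision_recall_fbeta(tp, fp, fn)
_, _, f_half = precision_recall_fbeta(tp, fp, fn, beta=0.5)

print(tp + fp + fn + tn)  # 597
print(f"{precision:.4f} {recall:.4f} {f1:.4f} {f_half:.4f}")
# 0.9831 0.9537 0.9682 0.9771
print(doc_prediction([0.03, 0.62, 0.10]))  # True
```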
+
+ **Chunk-recall caveat**: even when the doc-level prediction is correct, the
+ **top-1 chunk** contains the gold statement verbatim only ~68% of the time
+ (top-2 covers ~88%). This is why the downstream cascade uses **top-K=2**
+ chunks: it raises the chance that the gold-containing chunk is fed to the
+ span head.
+
+ ## Intended use
+
+ Doc-level filtering of arXiv-derived PDFs for funding-statement presence, and
+ stage 1 of the funding-extraction cascade. Useful when you want to skip
+ expensive span extraction on papers that have no funding statement (a
+ sizable fraction of arXiv papers).
+
+ Not intended for: extraction (it only classifies chunks; pair with the
+ span-head model for spans), classification of funding sources, or text
+ outside the academic-paper domain.
+
+ ## Limitations
+
+ - Trained only on arXiv-derived PDFs; behavior on other paper sources is
+   untested.
+ - The top-1 chunk misses the gold statement ~32% of the time even when the
+   doc-level prediction is correct. Use top-K ≥ 2 if you need recall.
+ - Mean-pooling over 8,192 tokens dilutes the signal from a short funding
+   statement (median ~272 characters), so the false-negative rate at a strict
+   threshold of 0.9 is non-trivial. Use 0.5 (or lower) and rely on the span
+   head's `no_answer` head to suppress empty chunks.
+
+ ## Citation / acknowledgement
+
+ Trained as part of an applied research cycle on the
+ `cometadata/arxiv-pdf-only-works-funding-statement-extraction-train-test`
+ dataset by Comet.
modeling.py ADDED
@@ -0,0 +1,34 @@
+ """Custom model class for funding-chunk-classifier-modernbert-base.
+
+ Usage:
+     import torch
+     from huggingface_hub import hf_hub_download
+     from transformers import AutoTokenizer
+     from modeling import ChunkClassifier
+
+     REPO = "cometadata/funding-chunk-classifier-modernbert-base"
+     tokenizer = AutoTokenizer.from_pretrained(REPO)
+     model = ChunkClassifier().to("cuda")
+     sd = torch.load(hf_hub_download(REPO, "pytorch_model.bin"),
+                     map_location="cuda", weights_only=True)
+     model.load_state_dict(sd)
+     model.eval()
+ """
+ import torch.nn as nn
+ from transformers import AutoModel
+
+
+ class ChunkClassifier(nn.Module):
+     """ModernBERT-base encoder + mean-pool + binary head for funding-chunk detection."""
+
+     def __init__(self, base: str = "answerdotai/ModernBERT-base"):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(base)
+         self.head = nn.Linear(self.encoder.config.hidden_size, 1)
+
+     def forward(self, input_ids, attention_mask):
+         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         # Mean pool over real (non-padding) tokens
+         mask = attention_mask.unsqueeze(-1).float()
+         pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
+         return self.head(pooled).squeeze(-1)  # one logit per chunk
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:566e172d7db3c9201011533a74503592384882622568f171b27bd47f1708e5ba
+ size 596119575
tokenizer.json ADDED
The diff for this file is too large to render.
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "backend": "tokenizers",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "is_local": false,
+   "mask_token": "[MASK]",
+   "model_input_names": [
+     "input_ids",
+     "attention_mask"
+   ],
+   "model_max_length": 8192,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "tokenizer_class": "TokenizersBackend",
+   "unk_token": "[UNK]"
+ }