kadarakos commited on
Commit
ba006b9
·
1 Parent(s): f84045f
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
pyproject.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "mentioned"
3
+ version = "0.1.0"
4
+ description = "Sentence-level mention detection trained on LitBank (proof of concept)"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "kadarakos", email = "kadar.akos@gmail.com" }
8
+ ]
9
+ requires-python = ">=3.12"
10
+ dependencies = [
11
+ "datasets>=4.6.0",
12
+ "huggingface-hub>=1.4.1",
13
+ "lightning>=2.6.1",
14
+ "torch>=2.10.0",
15
+ "torchmetrics>=1.8.2",
16
+ "transformers>=5.2.0",
17
+ "wandb>=0.25.0",
18
+ ]
19
+
20
+ [build-system]
21
+ requires = ["uv_build>=0.9.9,<0.10.0"]
22
+ build-backend = "uv_build"
23
+
24
+ [dependency-groups]
25
+ dev = [
26
+ "ruff>=0.15.2",
27
+ ]
src/mentioned/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
def hello() -> str:
    """Return the package's greeting string."""
    greeting = "Hello from mentioned!"
    return greeting
src/mentioned/data.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ from collections import defaultdict
5
+
6
+ from torch.nn.utils.rnn import pad_sequence
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from datasets import load_dataset
9
+
10
+
11
def mentions_by_sentence(example):
    """Group an example's coreference mentions by the sentence they occur in.

    Adds a "mentions" field mapping sentence index (as a string) to a list of
    (start, end) token offsets, then returns the mutated example.
    """
    grouped = defaultdict(list)
    for chain in example["coref_chains"]:
        for sent_idx, start, end in chain:
            # Arrow datasets require str/bytes dictionary keys.
            grouped[str(sent_idx)].append((start, end))
    example["mentions"] = grouped
    return example
20
+
21
+
22
def flatten_to_sentences(batch):
    """Explode document-level rows into one row per sentence.

    Pairs each sentence with the mentions recorded under its index; sentences
    without an entry get an empty mention list.
    """
    sentences_out = []
    mentions_out = []
    for doc_sentences, doc_mentions in zip(batch["sentences"], batch["mentions"]):
        # Some versions of datasets store an empty dict as None.
        lookup = {} if doc_mentions is None else doc_mentions
        for idx, sentence in enumerate(doc_sentences):
            sentences_out.append(sentence)
            # Missing index -> no mentions in that sentence.
            mentions_out.append(lookup.get(str(idx), []))
    return {"sentence": sentences_out, "mentions": mentions_out}
39
+
40
+
41
class LitBankStringDataset(Dataset):
    """Torch dataset exposing per-sentence tokens with start/span targets."""

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        tokens = row["sentence"]
        # The ArrowDataset yields None instead of an empty list.
        mentions = row["mentions"] if row["mentions"] is not None else []

        n = len(tokens)
        start_targets = torch.zeros(n, dtype=torch.long)
        span_targets = torch.zeros((n, n), dtype=torch.long)

        for start, end in mentions:
            # Drop out-of-range annotations (LitBank `end` is often inclusive).
            if start < n and end < n:
                start_targets[start] = 1
                span_targets[start, end] = 1

        return {
            "tokens": tokens,
            "starts": start_targets,
            "span_labels": span_targets,
        }
69
+
70
+
71
def collate_fn(batch):
    """Pad a list of LitBankStringDataset items into batched tensors.

    Returns the raw sentences plus padded start/span targets together with
    the boolean masks that restrict both losses to real (non-pad) positions.
    """
    sentences = [example["tokens"] for example in batch]
    longest = max(len(tokens) for tokens in sentences)

    start_targets = [example["starts"] for example in batch]
    padded_spans = []
    for example in batch:
        n = len(example["starts"])
        canvas = torch.zeros((longest, longest), dtype=torch.long)
        canvas[:n, :n] = example["span_labels"]
        padded_spans.append(canvas)

    # Pad with -1 so real positions stay recoverable, then zero the padding.
    starts = pad_sequence(start_targets, batch_first=True, padding_value=-1)
    token_mask = starts != -1
    starts[starts == -1] = 0

    # Batched 2D targets for token-pair classification: (B, N, N).
    spans = torch.stack(padded_spans)
    # A pair (i, j) is valid when both ends fall inside the sentence ...
    pair_mask = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
    # ... j does not precede i ...
    triangular = torch.triu(
        torch.ones((longest, longest), dtype=torch.bool),
        diagonal=0,
    )
    # ... and i is a gold start token: (B, N, 1) broadcast over columns.
    start_rows = starts.unsqueeze(2).bool()
    # "And" all constraints together (like attention masking): (B, N, N).
    span_loss_mask = pair_mask & triangular & start_rows

    return {
        "sentences": sentences,
        "starts": starts,  # (B, N) - Targets for start classifier
        "spans": spans,  # (B, N, N) - Targets for span classifier
        "token_mask": token_mask,  # (B, N) - For 1D loss
        "span_loss_mask": span_loss_mask,  # (B, N, N) - For 2D loss
    }
111
+
112
+
113
def make_litbank() -> tuple[DataLoader, DataLoader, DataLoader]:
    """Reformat LitBank into sentence-level mention-detection DataLoaders.

    Returns:
        (train, validation, test) DataLoaders yielding the padded batches
        produced by `collate_fn`.
    """
    litbank = load_dataset("coref-data/litbank_raw", "split_0")
    litbank_sentences_mentions = litbank.map(mentions_by_sentence).map(
        flatten_to_sentences, batched=True, remove_columns=litbank["train"].column_names
    )
    # Diagnostic: count training sentences with no mentions. Column access
    # avoids the per-row decode of indexing the dataset item by item.
    no = sum(
        1
        for mentions in litbank_sentences_mentions["train"]["mentions"]
        if mentions is None or len(mentions) == 0
    )
    print(f"Training sentences without mentions: {no}.")
    train = LitBankStringDataset(litbank_sentences_mentions["train"])
    val = LitBankStringDataset(litbank_sentences_mentions["validation"])
    test = LitBankStringDataset(litbank_sentences_mentions["test"])
    train_loader = DataLoader(train, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val, batch_size=4, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test, batch_size=4, shuffle=False, collate_fn=collate_fn)
    # Sanity check: fail fast here if collation is broken. The previous
    # `try: ... except Exception as e: raise e` was a no-op wrapper.
    next(iter(train_loader))
    return train_loader, val_loader, test_loader
src/mentioned/inference.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from mentioned.model import make_model_v1, LitMentionDetector
5
+
6
+
7
class InferenceMentionDetector(nn.Module):
    """Export-friendly wrapper bundling the encoder and the detector head.

    Applies sigmoid inside `forward` so an exported graph (e.g. ONNX)
    returns final probabilities rather than raw logits.
    """

    def __init__(self, encoder, mention_detector):
        super().__init__()
        self.encoder = encoder
        self.mention_detector = mention_detector

    def forward(self, input_ids, attention_mask, word_ids):
        """
        Inputs (Tensors):
            input_ids: (B, Seq_Len)
            attention_mask: (B, Seq_Len)
            word_ids: (B, Seq_Len) -> Word index per token, -1 for special tokens

        Returns (Tensors):
            start_probs: (B, Num_Words)
            end_probs: (B, Num_Words, Num_Words)
        """
        # Pool subwords into word vectors: (Batch, Num_Words, Hidden_Dim).
        pooled = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_ids=word_ids,
        )
        # Score starts (B, W) and start/end pairs (B, W, W).
        start_scores, end_scores = self.mention_detector(pooled)
        # Convert to probabilities so the exported model emits final scores.
        return torch.sigmoid(start_scores), torch.sigmoid(end_scores)
43
+
44
+
45
class MentionProcessor:
    """Turn pre-tokenized documents into the tensor inputs the model expects."""

    def __init__(self, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, docs: list[list[str]]):
        """
        Converts raw word lists into tensors for the ONNX model.
        Args:
            docs: List of documents, where each doc is a list of words.
                  Example: [["Hello", "world"], ["Testing", "this"]]
        """
        # Tokenize; is_split_into_words=True because inputs are word lists.
        encoded = self.tokenizer(
            docs,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_attention_mask=True,
        )

        # Map each subword position to its word index; special/pad tokens
        # (word_id None) become -1 so the pooling layer can ignore them.
        word_rows = [
            torch.tensor(
                [-1 if w is None else w for w in encoded.word_ids(batch_index=i)]
            )
            for i in range(len(docs))
        ]

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "word_ids": torch.stack(word_rows),  # (Batch, Seq_Len)
        }
86
+
87
+
88
class MentionDetectorPipeline:
    """Tokenize docs, run the mention model, and decode spans above threshold."""

    def __init__(self, model, tokenizer, threshold: float = 0.5):
        """
        Args:
            model: The InferenceMentionDetector (PyTorch or ONNX Session)
            tokenizer: The PreTrainedTokenizer
            threshold: Probability threshold to consider a mention valid
        """
        self.model = model.eval()
        # Assumes `model.max_length` was attached (create_inference_model does this).
        self.processor = MentionProcessor(tokenizer, model.max_length)
        self.threshold = threshold

    @torch.no_grad()
    def predict(self, docs: list[list[str]]):
        """
        Args:
            docs: List of documents (each is a list of words)
        Returns:
            List of lists containing dicts: {"start": int, "end": int, "score": float}
        """
        # 1. Preprocess to Tensors
        batch = self.processor(docs)
        device = next(self.model.parameters()).device

        # Move batch to model device
        batch = {k: v.to(device) for k, v in batch.items()}

        # 2. Forward Pass
        # start_probs: (B, W), end_probs: (B, W, W)
        start_probs, end_probs = self.model(**batch)

        # 3. Post-process: Extract Mentions
        # NOTE(review): only end_probs is thresholded below; start_probs is
        # computed but unused. Confirm whether the combined P(start)*P(end)
        # score (as in LitMentionDetector.predict_mentions) was intended.
        results = []
        for i in range(len(docs)):
            doc_mentions = []
            doc_len = len(docs[i])

            # We only look at the valid word range for this specific document
            # end_probs[i] is a (W, W) matrix where [s, e] is the prob of span s->e
            valid_spans = (end_probs[i][:doc_len, :doc_len] > self.threshold).nonzero()

            for span in valid_spans:
                start_idx = span[0].item()
                end_idx = span[1].item()

                # Logic: Only valid if end >= start
                if end_idx >= start_idx:
                    score = end_probs[i, start_idx, end_idx].item()
                    doc_mentions.append({
                        "start": start_idx,
                        "end": end_idx,
                        "score": round(score, 4),
                        # `end` index is inclusive, hence the +1 slice.
                        "text": " ".join(docs[i][start_idx : end_idx + 1])
                    })

            results.append(doc_mentions)

        return results
146
+
147
+
148
+ def create_inference_model(repo_id: str, device: str = "cpu"):
149
+ """
150
+ Factory to load a trained model from HF Hub and wrap it for ONNX/Inference.
151
+ """
152
+ # 1. Load the Lightning model (with its weights)
153
+ # Note: Ensure LitMentionDetector is defined in your scope
154
+ fresh_model = make_model_v1()
155
+ lit_model = LitMentionDetector.from_pretrained(
156
+ repo_id,
157
+ tokenizer=fresh_model.tokenizer,
158
+ encoder=fresh_model.encoder,
159
+ mention_detector=fresh_model.mention_detector,
160
+ )
161
+
162
+ # 2. Move to device and set to eval mode
163
+ lit_model.to(device)
164
+ lit_model.eval()
165
+
166
+ # 3. Wrap the core components into the Inference class
167
+ # This separates the 'training' logic from the 'inference' graph
168
+ inference_model = InferenceMentionDetector(
169
+ encoder=lit_model.encoder,
170
+ mention_detector=lit_model.mention_detector
171
+ )
172
+
173
+ # 4. Attach the tokenizer and max_length for the Preprocessor
174
+ # (Optional: helpful for keeping metadata together)
175
+ inference_model.tokenizer = lit_model.tokenizer
176
+ inference_model.max_length = lit_model.encoder.max_length
177
+
178
+ return inference_model.eval()
179
+
180
+
181
# TODO
def compile_inference_model(model):
    """No-op placeholder for future torch.compile / ONNX export of the model."""
    return model
184
+
185
+
186
+ repo_id = "kadarakos/mention-detector-poc-dry-run"
187
+ inference_model = compile_inference_model(
188
+ create_inference_model(repo_id)
189
+ )
190
+ pipeline = MentionDetectorPipeline(
191
+ model=inference_model,
192
+ tokenizer=inference_model.tokenizer,
193
+ threshold=0.6 # Noticed that precision is bad below this (still bad :D).
194
+ )
195
+
196
+ docs = [
197
+ "Does this model actually work?".split(),
198
+ "The name of the mage is Bubba.".split(),
199
+ "It was quite a sunny day when the model finally started working.".split(),
200
+ "Albert Einstein was a theoretical physicist who developed the theory of relativity".split(),
201
+ "Apple Inc. and Microsoft are competing in the cloud computing market".split(),
202
+ "New York City is often called the Big Apple".split(),
203
+ "The Great Barrier Reef is the world's largest coral reef system".split(),
204
+ "Marie Curie was the first woman to win a Nobel Prize".split()
205
+ ]
206
+
207
+ batch_mentions = pipeline.predict(docs)
208
+ for i, mentions in enumerate(batch_mentions):
209
+ print(docs[i])
210
+ for mention in mentions:
211
+ print(mention["text"])
src/mentioned/model.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchmetrics
3
+
4
+
5
+ from transformers import AutoTokenizer, AutoModel
6
+ from huggingface_hub import PyTorchModelHubMixin
7
+ from lightning import LightningModule
8
+
9
+
10
class SentenceEncoder(torch.nn.Module):
    """Transformer encoder that mean-pools subword vectors into word vectors."""

    def __init__(
        self,
        model_name: str = "distilroberta-base",
        max_length: int = 512,
    ):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.max_length = max_length
        # Hidden size of the underlying transformer, read from its config.
        self.dim = self.encoder.config.hidden_size
        # NOTE(review): never read or written elsewhere in this file — confirm
        # whether this is still needed.
        self.stats = {}

    def forward(self, input_ids, attention_mask, word_ids):
        """
        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len)
            word_ids: (batch, seq_len) -> Pre-computed word indices,
                      use -1 for special tokens/padding.
        Returns:
            Mean-pooled word embeddings of shape (batch, num_words, hidden_dim).
        """
        # 1. Get Transformer Output
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        subword_embeddings = outputs.last_hidden_state  # (B, S, D)
        # Words are indexed 0..max over the whole batch; -1 entries (special
        # tokens / padding) match no column of the one-hot mask below.
        num_words = word_ids.max() + 1
        word_mask = word_ids.unsqueeze(-1) == torch.arange(
            num_words, device=word_ids.device
        )
        word_mask = word_mask.float()  # (B, S, W)
        # Sum embeddings for each word: (B, W, S) @ (B, S, D) -> (B, W, D)
        word_sums = torch.bmm(word_mask.transpose(1, 2), subword_embeddings)
        # Count subwords per word to get the denominator
        # (B, W, S) summed over S -> (B, W, 1); clamp avoids divide-by-zero
        # for word slots that received no subword.
        subword_counts = word_mask.sum(dim=1).unsqueeze(-1).clamp(min=1e-9)
        # (B, W, D) -- mean over the subwords of each word.
        word_embeddings = word_sums / subword_counts
        return word_embeddings
47
+
48
+
49
class SwiGLU(torch.nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim: int, hidden_dim: int = None):
        super().__init__()
        if hidden_dim is None:
            # Common expansion factor of 2.
            hidden_dim = 2 * dim
        self.w1 = torch.nn.Linear(dim, hidden_dim)
        self.w3 = torch.nn.Linear(dim, hidden_dim)
        self.w2 = torch.nn.Linear(hidden_dim, dim)
        self.silu = torch.nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate the second projection, then map back to the input dimension.
        return self.w2(self.silu(self.w1(x)) * self.w3(x))
65
+
66
+
67
class Detector(torch.nn.Module):
    """Two-layer MLP scoring head emitting one logit per position or pair."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # A 2-layer MLP is the standard choice for span detection so the
        # head can model feature interactions.
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (..., input_dim) — e.g. (B, N, input_dim) for start detection
               or (B, N, N, input_dim) for pair detection.
        Returns:
            Logits of shape (..., 1); callers squeeze the trailing dim.
        """
        return self.net(x)
86
+
87
+
88
class MentionDetectorCore(torch.nn.Module):
    """Scores each token as a mention start and each (start, end) token pair."""

    def __init__(
        self,
        start_detector: "Detector",
        end_detector: "Detector",
    ):
        super().__init__()
        self.start_detector = start_detector
        self.end_detector = end_detector

    def forward(self, emb: torch.Tensor):
        """
        Args:
            emb: (Batch, Seq_Len, Hidden_Dim)
        Returns:
            start_logits: (Batch, Seq_Len)
            end_logits: (Batch, Seq_Len, Seq_Len)
        """
        seq_len = emb.shape[1]
        start_logits = self.start_detector(emb).squeeze(-1)
        # Pair representation [emb_i ; emb_j] for every (i, j): (B, N, N, 2H).
        rows = emb.unsqueeze(2).expand(-1, -1, seq_len, -1)
        cols = emb.unsqueeze(1).expand(-1, seq_len, -1, -1)
        pairs = torch.cat([rows, cols], dim=-1)
        end_logits = self.end_detector(pairs).squeeze(-1)
        return start_logits, end_logits
114
+
115
+
116
class LitMentionDetector(LightningModule, PyTorchModelHubMixin):
    """Lightning module training start/span detector heads on a frozen encoder.

    Losses are masked binary cross-entropy over start positions (1D) and
    over (start, end) token pairs (2D). The encoder's parameters are frozen
    in __init__; only the mention-detector heads are optimized.
    """

    def __init__(
        self,
        tokenizer,  #: transformers.PreTrainedTokenizer,
        encoder: torch.nn.Module,
        mention_detector: torch.nn.Module,
        lr: float = 2e-5,
        threshold: float = 0.5,
    ):
        super().__init__()
        # Fixed: the previous ignore list named non-existent parameters
        # ("start_detector"/"end_detector"), so `tokenizer` and
        # `mention_detector` were pickled into hparams and checkpoints.
        # Only lr and threshold belong in hparams; the submodules are
        # re-supplied explicitly on load (see train.py).
        self.save_hyperparameters(ignore=["tokenizer", "encoder", "mention_detector"])
        self.tokenizer = tokenizer
        self.encoder = encoder
        # Freeze all encoder parameters
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.mention_detector = mention_detector

        # reduction="none" so padding can be masked out before averaging.
        self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")

        # Two separate metrics for the two sub-tasks, plus the combined one.
        self.val_f1_start = torchmetrics.classification.BinaryF1Score()
        self.val_f1_end = torchmetrics.classification.BinaryF1Score()
        self.val_f1_mention = torchmetrics.classification.BinaryF1Score()

    def encode(self, docs: list[list[str]]):
        """
        Handles the non-vectorized tokenization and calls the vectorized encoder.

        Args:
            docs: pre-tokenized sentences (lists of words).
        Returns:
            Word-level embeddings of shape (B, Num_Words, Hidden_Dim).
        """
        device = next(self.parameters()).device
        inputs = self.tokenizer(
            docs,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=self.encoder.max_length,
            padding=True,
            return_attention_mask=True,
            return_offsets_mapping=True,  # needed for word_ids
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        # Per-token word indices; special/pad tokens (None) become -1 so the
        # encoder's pooling ignores them.
        batch_word_ids = []
        for i in range(len(docs)):
            w_ids = [w if w is not None else -1 for w in inputs.word_ids(batch_index=i)]
            batch_word_ids.append(torch.tensor(w_ids))

        word_ids_tensor = torch.stack(batch_word_ids).to(device)
        word_embeddings = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, word_ids=word_ids_tensor
        )
        return word_embeddings

    def forward(self, emb: torch.Tensor):
        """Score word embeddings; returns (start_logits, end_logits)."""
        start_logits, end_logits = self.mention_detector(emb)
        return start_logits, end_logits

    def _compute_start_loss(self, start_logits, batch):
        # Mean BCE over real (non-pad) token positions only.
        targets = batch["starts"].float()
        mask = batch["token_mask"].bool()
        return self.loss_fn(start_logits, targets)[mask].mean()

    def _compute_end_loss(self, end_logits, batch):
        # Mean BCE over valid (gold-start, upper-triangular, in-range) pairs.
        targets = batch["spans"].float()
        mask = batch["span_loss_mask"].bool()
        raw_loss = self.loss_fn(end_logits, targets)
        relevant_loss = raw_loss[mask]

        # No valid pairs in the batch: return a zero that keeps the graph
        # connected so backward() still works.
        if relevant_loss.numel() == 0:
            return end_logits.sum() * 0
        return relevant_loss.mean()

    def training_step(self, batch, batch_idx):
        emb = self.encode(batch["sentences"])
        start_logits, end_logits = self.forward(emb)
        loss_start = self._compute_start_loss(start_logits, batch)
        loss_end = self._compute_end_loss(end_logits, batch)
        total_loss = loss_start + loss_end
        self.log_dict(
            {
                "train_loss": total_loss,
                "train_start_loss": loss_start,
                "train_end_loss": loss_end,
            },
            prog_bar=True,
        )

        return total_loss

    def validation_step(self, batch, batch_idx):
        emb = self.encode(batch["sentences"])
        start_logits, end_logits = self.forward(emb)
        token_mask = batch["token_mask"].bool()
        span_loss_mask = batch["span_loss_mask"].bool()

        # --- METRIC 1: Start Detection (Diagnostic) ---
        start_preds = (
            torch.sigmoid(start_logits[token_mask]) > self.hparams.threshold
        ).int()
        start_targets = batch["starts"][token_mask].int()
        if start_targets.numel() > 0:
            self.val_f1_start.update(start_preds, start_targets)

        # --- METRIC 2: End Detection (Diagnostic / Teacher Forced) ---
        # Evaluates end-detector ONLY on ground-truth start positions
        end_preds_diag = (
            torch.sigmoid(end_logits[span_loss_mask]) > self.hparams.threshold
        ).int()
        end_targets_diag = batch["spans"][span_loss_mask].int()
        if end_targets_diag.numel() > 0:
            self.val_f1_end.update(end_preds_diag, end_targets_diag)

        # --- METRIC 3: Full Mention Detection (The "Final Boss") ---
        # A mention is correct only if BOTH start and end are predicted correctly.
        # Combined probability: P(Start) * P(End)
        combined_probs = torch.sigmoid(start_logits).unsqueeze(2) * torch.sigmoid(
            end_logits
        )

        # We evaluate every possible pair in the valid upper triangle of the sentence
        # (Excluding padding and j < i)
        valid_pair_mask = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
        upper_tri = torch.triu(torch.ones_like(end_logits), diagonal=0).bool()
        mention_eval_mask = valid_pair_mask & upper_tri

        mention_preds = (
            combined_probs[mention_eval_mask] > self.hparams.threshold
        ).int()
        mention_targets = batch["spans"][mention_eval_mask].int()

        if mention_targets.numel() > 0:
            self.val_f1_mention.update(mention_preds, mention_targets)

        # --- 4. Logging ---
        start_loss = self._compute_start_loss(start_logits, batch)
        end_loss = self._compute_end_loss(end_logits, batch)

        self.log_dict(
            {
                "val_loss": start_loss + end_loss,
                "val_f1_start": self.val_f1_start,
                "val_f1_end": self.val_f1_end,
                "val_f1_mention": self.val_f1_mention,
            },
            prog_bar=True,
            batch_size=len(batch["sentences"]),
            on_epoch=True,
        )

    @torch.no_grad()
    def predict_mentions(
        self, sentences: list[list[str]], batch_size: int = 2
    ) -> list[list[tuple[int, int]]]:
        """
        Args:
            sentences: A list of tokenized sentences.
        Returns:
            A list (per sentence) of lists containing (start_idx, end_idx) tuples.
        """
        self.eval()
        all_results = []

        # Process in batches to avoid OOM on large datasets
        for i in range(0, len(sentences), batch_size):
            batch_sentences = sentences[i : i + batch_size]
            # Fixed: was `self.encoder(batch_sentences)`, but the encoder's
            # forward expects (input_ids, attention_mask, word_ids) tensors;
            # encode() performs the tokenization first.
            emb = self.encode(batch_sentences)  # (B, N, H)
            B, N, _ = emb.shape
            start_logits, end_logits = self.forward(emb)

            start_probs = torch.sigmoid(start_logits)  # (B, N)
            end_probs = torch.sigmoid(end_logits)  # (B, N, N)

            # 3. Calculate Joint Confidence
            # (B, N, 1) * (B, N, N) -> (B, N, N)
            combined_probs = start_probs.unsqueeze(2) * end_probs

            # 4. Filter by Constraints (Upper Triangle & Threshold)
            # Create mask for j >= i
            upper_tri = torch.triu(
                torch.ones((N, N), device=self.device), diagonal=0
            ).bool()

            # Apply threshold and upper triangle constraint
            pred_mask = (combined_probs > self.hparams.threshold) & upper_tri

            # 5. Extract Indices
            # nonzero() returns [batch_idx, start_idx, end_idx]
            indices = pred_mask.nonzero()

            # Organize results back into a list of lists (one per sentence in batch)
            batch_results = [[] for _ in range(len(batch_sentences))]
            for b_idx, s_idx, e_idx in indices:
                # Convert to standard Python ints for the final output
                batch_results[b_idx.item()].append((s_idx.item(), e_idx.item()))

            all_results.extend(batch_results)

        return all_results

    def test_step(self, batch, batch_idx):
        # Reuse all the logic from validation_step
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        # The frozen encoder params have requires_grad=False, so AdamW only
        # updates the detector heads.
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
321
+
322
+
323
+ def make_model_v1(model_name="distilroberta-base"):
324
+ dim = 768
325
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
326
+ encoder = SentenceEncoder(model_name).train()
327
+ encoder.train()
328
+ start_detector = Detector(dim, dim)
329
+ end_detector = Detector(dim * 2, dim)
330
+ mention_detector = MentionDetectorCore(start_detector, end_detector)
331
+ return LitMentionDetector(tokenizer, encoder, mention_detector)
src/mentioned/py.typed ADDED
File without changes
src/mentioned/train.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import wandb

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger

# Fixed import: the package is "mentioned" (see pyproject.toml and the
# sibling import above) — "mentions.data" does not exist and raised
# ModuleNotFoundError on import.
from mentioned.data import make_litbank
from mentioned.model import make_model_v1, LitMentionDetector
10
+
11
+
12
def train():
    """Train the PoC mention detector on LitBank, push the best checkpoint to
    the HF Hub, then re-download it and verify it on the test set."""
    train_loader, val_loader, test_loader = make_litbank()
    model = make_model_v1()
    wandb_logger = WandbLogger(
        project="mention-detector-poc",
        name="distilroberta-frozen-encoder",
    )
    # Save only the best model for the PoC purposes.
    best_checkpoint = ModelCheckpoint(
        monitor="val_f1_mention",
        mode="max",
        save_top_k=1,
        filename="best-mention-f1",
        verbose=True,
    )
    early_stopper = EarlyStopping(
        monitor="val_f1_mention",
        min_delta=0.01,
        patience=5,
        verbose=True,
        mode="max",
    )
    # Validate every 1000 training steps rather than once per epoch.
    trainer = Trainer(
        val_check_interval=1000,
        check_val_every_n_epoch=None,
        callbacks=[early_stopper, best_checkpoint],
        logger=wandb_logger,
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )
    # Evaluate the best (not the last) checkpoint.
    trainer.test(dataloaders=test_loader, ckpt_path="best", weights_only=False)
    # Rebuild the architecture and load the best weights into it; the
    # submodules are passed explicitly rather than restored from hparams.
    fresh_model = make_model_v1()
    best_model = LitMentionDetector.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path,
        tokenizer=fresh_model.tokenizer,
        encoder=fresh_model.encoder,
        mention_detector=fresh_model.mention_detector,
        weights_only=False,
    )
    best_model.push_to_hub("kadarakos/mention-detector-poc-dry-run", private=True)
    wandb.finish()

    ### Test pull: round-trip through the Hub to confirm the upload loads.
    fresh_model = make_model_v1()
    repo_id = "kadarakos/mention-detector-poc-dry-run"
    remote_model = LitMentionDetector.from_pretrained(
        repo_id,
        tokenizer=fresh_model.tokenizer,
        encoder=fresh_model.encoder,
        mention_detector=fresh_model.mention_detector,
    )

    # 3. Final Verification
    verify_trainer = Trainer(accelerator="auto", logger=False)
    verify_trainer.test(model=remote_model, dataloaders=test_loader)