| | import torch |
| | import torchmetrics |
| |
|
| | from transformers import AutoTokenizer, AutoModel |
| | from huggingface_hub import PyTorchModelHubMixin |
| | from lightning import LightningModule |
| |
|
| | from mentioned.data import DataBlob |
| |
|
class ModelRegistry:
    """Name -> factory registry for model constructors.

    Factories are added with the ``@ModelRegistry.register(name)`` decorator
    and looked up by name via :meth:`get`. The registry is shared class
    state, so registration at import time is visible process-wide.
    """

    # Maps registered name -> factory callable (shared across the class).
    _registry = {}

    @classmethod
    def register(cls, name):
        """Return a decorator that registers the decorated callable under ``name``.

        The callable is returned unchanged, so the decorator is transparent.
        """
        def decorator(func):
            cls._registry[name] = func
            return func
        return decorator

    @classmethod
    def get(cls, name):
        """Return the factory registered under ``name``.

        Raises:
            KeyError: if ``name`` was never registered. The message lists the
                registered names so typos are easy to spot.
        """
        try:
            return cls._registry[name]
        except KeyError:
            known = ", ".join(sorted(cls._registry)) or "<none>"
            raise KeyError(
                f"Unknown model name {name!r}; registered names: {known}"
            ) from None
| |
|
| |
|
class SentenceEncoder(torch.nn.Module):
    """Wraps a HuggingFace transformer and mean-pools its sub-word states
    into one embedding per word."""

    def __init__(
        self,
        model_name: str = "distilroberta-base",
        max_length: int = 512,
    ):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
        )
        self.encoder = AutoModel.from_pretrained(model_name)
        self.max_length = max_length
        # Hidden size of the backbone; downstream heads size themselves off this.
        self.dim = self.encoder.config.hidden_size
        self.stats = {}

    def forward(self, input_ids, attention_mask, word_ids):
        """Encode tokens and average sub-word embeddings per word.

        Args:
            input_ids: B x N token ids.
            attention_mask: B x N attention mask.
            word_ids: B x N word index for each sub-word token; entries that
                belong to no word (special/padding tokens) are negative and
                are excluded from every word's average.
        Returns:
            B x W x H word embeddings, where W = word_ids.max() + 1.
        """
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        total_words = int(word_ids.max()) + 1
        # One-hot assignment of each sub-word token to its word: B x N x W.
        # Negative ids never match arange(), so specials contribute nothing.
        assignment = (
            word_ids.unsqueeze(-1)
            == torch.arange(total_words, device=word_ids.device)
        ).to(hidden.dtype)
        # Sum the sub-word vectors of every word in one batched matmul.
        summed = torch.bmm(assignment.transpose(1, 2), hidden)
        # Per-word sub-word counts; clamp guards against divide-by-zero for
        # words absent from a given batch row.
        counts = assignment.sum(dim=1).unsqueeze(-1).clamp(min=1e-9)
        return summed / counts
| |
|
| |
|
class Detector(torch.nn.Module):
    """Two-layer MLP scoring head: Linear -> ReLU -> Linear."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_classes: int = 1,
    ):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_classes),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score every position (or position pair).

        Args:
            x: (B, N, input_dim) for start detection,
               (B, N, N, input_dim) for end/pair detection.
        Returns:
            logits with shape x.shape[:-1] + (num_classes,); callers squeeze
            the trailing axis when num_classes == 1.
        """
        return self.net(x)
| |
|
| |
|
class MentionDetectorCore(torch.nn.Module):
    """Bundles a start head and a (start, end) pair head into one module."""

    def __init__(
        self,
        start_detector: Detector,
        end_detector: Detector,
    ):
        super().__init__()
        self.start_detector = start_detector
        self.end_detector = end_detector

    def forward(self, emb: torch.Tensor):
        """Score span starts and all start/end pairs.

        Args:
            emb: (Batch, Seq_Len, Hidden_Dim) word embeddings.
        Returns:
            start_logits: (Batch, Seq_Len) — one score per start position.
            end_logits: (Batch, Seq_Len, Seq_Len) — one score per
                (start, end) pair.
        """
        num_words = emb.size(1)
        start_logits = self.start_detector(emb).squeeze(-1)
        # Pair every candidate start with every candidate end by concatenating
        # the two word embeddings along the feature axis -> (B, N, N, 2H).
        pairs = torch.cat(
            [
                emb.unsqueeze(2).expand(-1, -1, num_words, -1),
                emb.unsqueeze(1).expand(-1, num_words, -1, -1),
            ],
            dim=-1,
        )
        end_logits = self.end_detector(pairs).squeeze(-1)
        return start_logits, end_logits
| |
|
| |
|
class MentionLabeler(torch.nn.Module):
    """Assigns entity-type logits to every (start, end) word pair."""

    def __init__(self, classifier: Detector):
        super().__init__()
        self.classifier = classifier

    def forward(self, emb: torch.Tensor):
        """Classify all start/end pairs.

        Args:
            emb: (Batch, Seq_Len, Hidden_Dim) word embeddings.
        Returns:
            logits: (Batch, Seq_Len, Seq_Len, num_classes); the trailing axis
            is squeezed away when the classifier emits a single class.
        """
        n = emb.size(1)
        # Concatenate the start-word and end-word embeddings for every pair.
        starts = emb.unsqueeze(2).expand(-1, -1, n, -1)
        ends = emb.unsqueeze(1).expand(-1, n, -1, -1)
        return self.classifier(torch.cat([starts, ends], dim=-1)).squeeze(-1)
| | |
| |
|
class LitMentionDetector(LightningModule, PyTorchModelHubMixin):
    """LightningModule for span-based mention detection with an optional
    entity-labeling head.

    The frozen ``encoder`` maps tokenized sentences to per-word embeddings;
    ``mention_detector`` scores span starts and (start, end) pairs with
    binary logits, and ``mention_labeler`` (if provided) classifies each
    (start, end) pair into one of ``label2id``'s entity types. Batches carry
    a ``task_id`` (0 = plain mention detection, 1 = entity task) that selects
    which losses/metrics are computed.
    """

    def __init__(
        self,
        tokenizer,
        encoder: torch.nn.Module,
        mention_detector: torch.nn.Module,
        mention_labeler: torch.nn.Module | None = None,
        label2id: dict | None = None,
        lr: float = 2e-5,
        threshold: float = 0.5,
    ):
        """
        Args:
            tokenizer: HuggingFace-style fast tokenizer (must support
                ``is_split_into_words`` and ``word_ids``).
            encoder: module mapping token ids to word embeddings; frozen here.
            mention_detector: returns (start_logits, end_logits) from embeddings.
            mention_labeler: optional pair classifier; requires ``label2id``.
            label2id: mapping from entity-label string to class index.
            lr: AdamW learning rate (read back via ``self.hparams.lr``).
            threshold: sigmoid-probability cutoff for start/span decisions.

        Raises:
            ValueError: if ``mention_labeler`` is given without ``label2id``.
        """
        super().__init__()
        # Heavy sub-modules are excluded from hparams; the rest (tokenizer,
        # label2id, lr, threshold) is captured for checkpointing.
        self.save_hyperparameters(ignore=["encoder", "mention_detector", "mention_labeler"])
        self.tokenizer = tokenizer
        self.encoder = encoder
        # Freeze the encoder: only the detector/labeler heads are trained.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.mention_detector = mention_detector
        self.mention_labeler = mention_labeler
        self.label2id = label2id
        # reduction="none" so losses can be masked element-wise before averaging.
        self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")

        # Validation metrics for the plain mention task (task_id == 0).
        self.val_f1_start = torchmetrics.classification.BinaryF1Score()
        self.val_f1_end = torchmetrics.classification.BinaryF1Score()
        self.val_f1_mention = torchmetrics.classification.BinaryF1Score()

        if mention_labeler is not None:
            if label2id is None:
                raise ValueError("Need label2id!")
            num_classes = len(self.label2id)
            # Separate metric set for the entity task (task_id == 1).
            self.val_f1_entity_start = torchmetrics.classification.BinaryF1Score()
            self.val_f1_entity_end = torchmetrics.classification.BinaryF1Score()
            self.val_f1_entity_mention = torchmetrics.classification.BinaryF1Score()
            self.val_f1_entity_labels = torchmetrics.classification.MulticlassF1Score(
                num_classes=num_classes,
                average="macro"
            )
            self.entity_loss = torch.nn.CrossEntropyLoss()
            log_2 = torch.log(torch.tensor(2.0))
            # Weight = log(2)/log(C) = 1/log2(C): shrinks the entity-loss
            # contribution as the label set grows.
            # NOTE(review): C == 1 would give log(1) = 0 -> inf weight;
            # presumably label sets always have >= 2 classes — confirm.
            self.entity_weight = log_2 / torch.log(torch.tensor(float(num_classes)))

    def encode(self, docs: list[list[str]]):
        """
        Handles the non-vectorized tokenization and calls the vectorized encoder.

        Args:
            docs: batch of pre-split sentences (lists of words).
        Returns:
            Word embeddings from ``self.encoder`` — presumably
            (batch, num_words, hidden); confirm against the encoder.
        """
        device = next(self.parameters()).device
        inputs = self.tokenizer(
            docs,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=self.encoder.max_length,
            padding=True,
            return_attention_mask=True,
            return_offsets_mapping=True,
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        batch_word_ids = []
        for i in range(len(docs)):
            # Special/padding tokens have word_id None; encode them as -1 so
            # the encoder can exclude them from word pooling.
            w_ids = [w if w is not None else -1 for w in inputs.word_ids(batch_index=i)]
            batch_word_ids.append(torch.tensor(w_ids))

        word_ids_tensor = torch.stack(batch_word_ids).to(device)
        word_embeddings = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, word_ids=word_ids_tensor
        )
        return word_embeddings

    def forward_detector(self, emb: torch.Tensor):
        """Run the mention detector: returns (start_logits, end_logits)."""
        start_logits, end_logits = self.mention_detector(emb)
        return start_logits, end_logits

    def forward_labeler(self, emb: torch.Tensor):
        """Run the entity labeler: returns per-pair entity logits."""
        entity_logits = self.mention_labeler(emb)
        return entity_logits

    def _compute_start_loss(self, start_logits, batch):
        """Masked mean BCE over start positions (mask: real tokens only)."""
        targets = batch["starts"].float()
        mask = batch["token_mask"].bool()
        return self.loss_fn(start_logits, targets)[mask].mean()

    def _compute_end_loss(self, end_logits, batch):
        """Masked mean BCE over (start, end) pairs selected by span_loss_mask."""
        targets = batch["spans"].float()
        mask = batch["span_loss_mask"].bool()
        raw_loss = self.loss_fn(end_logits, targets)
        relevant_loss = raw_loss[mask]

        if relevant_loss.numel() == 0:
            # Zero loss that stays attached to the graph so backward() works.
            return end_logits.sum() * 0
        return relevant_loss.mean()

    def _compute_entity_loss(self, entity_logits, batch):
        """
        entity_logits shape: [batch, max_len, max_len, num_classes]

        Cross-entropy over gold (start, end) spans only; spans truncated away
        by the tokenizer's max_length are skipped.
        """
        preds = []
        targets = []

        for b, labels_dict in enumerate(batch["gold_labels"]):
            for (s, e), label_str in labels_dict.items():
                # Skip gold spans that fall outside the (possibly truncated) grid.
                if s < entity_logits.size(1) and e < entity_logits.size(2):
                    label_id = self.label2id[label_str]

                    preds.append(entity_logits[b, s, e])
                    targets.append(label_id)

        if not targets:
            # Graph-connected zero, as in _compute_end_loss.
            return entity_logits.sum() * 0

        preds_tensor = torch.stack(preds)
        targets_tensor = torch.tensor(targets, device=entity_logits.device)

        return self.entity_loss(preds_tensor, targets_tensor)

    def training_step(self, batch, batch_idx):
        """One optimization step: start + end BCE, plus weighted entity CE
        on entity-task batches."""
        emb = self.encode(batch["sentences"])
        start_logits, end_logits = self.forward_detector(emb)
        loss_start = self._compute_start_loss(start_logits, batch)
        loss_end = self._compute_end_loss(end_logits, batch)
        total_loss = loss_start + loss_end
        log_metrics = {
            "train_start_loss": loss_start,
            "train_end_loss": loss_end,
        }
        # task_id[0] stands in for the whole batch — assumes every sample in
        # a batch shares one task (confirm against the dataloader/sampler).
        if batch["task_id"][0] == 1:
            entity_logits = self.forward_labeler(emb)
            loss_entity = self._compute_entity_loss(entity_logits, batch)
            log_metrics["train_entity_loss"] = loss_entity
            total_loss = total_loss + self.entity_weight * loss_entity

        log_metrics["train_loss"] = total_loss
        self.log_dict(log_metrics, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Update task-specific F1 metrics and log the validation loss."""
        emb = self.encode(batch["sentences"])
        start_logits, end_logits = self.forward_detector(emb)

        token_mask = batch["token_mask"].bool()
        span_loss_mask = batch["span_loss_mask"].bool()

        # Hard decisions at the configured probability threshold.
        is_start = (torch.sigmoid(start_logits) > self.hparams.threshold).int()
        is_end = (torch.sigmoid(end_logits) > self.hparams.threshold).int()

        # A pair counts as a mention candidate only if both positions are real
        # tokens and end >= start (upper triangle, diagonal included).
        valid_pair_mask = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
        upper_tri = torch.triu(torch.ones_like(end_logits), diagonal=0).bool()
        mention_eval_mask = valid_pair_mask & upper_tri

        # A predicted mention requires both the start decision and the pair decision.
        pred_spans = (is_start.unsqueeze(2) & is_end)[mention_eval_mask]
        target_spans = batch["spans"][mention_eval_mask].int()

        log_stats = {}

        if batch["task_id"][0] == 0:
            # Plain mention-detection metrics.
            if token_mask.any():
                self.val_f1_start.update(is_start[token_mask], batch["starts"][token_mask].int())

            if span_loss_mask.any():
                self.val_f1_end.update(is_end[span_loss_mask], batch["spans"][span_loss_mask].int())

            if mention_eval_mask.any():
                self.val_f1_mention.update(pred_spans, target_spans)

            log_stats["val_f1_mention"] = self.val_f1_mention

        elif batch["task_id"][0] == 1:
            # Same span metrics, tracked separately for the entity task.
            if token_mask.any():
                self.val_f1_entity_start.update(is_start[token_mask], batch["starts"][token_mask].int())

            if span_loss_mask.any():
                self.val_f1_entity_end.update(is_end[span_loss_mask], batch["spans"][span_loss_mask].int())

            if mention_eval_mask.any():
                self.val_f1_entity_mention.update(pred_spans, target_spans)

            log_stats["val_f1_entity_mention"] = self.val_f1_entity_mention

        # Label-classification F1 over gold spans, evaluated whenever a
        # labeler exists (regardless of task_id).
        # NOTE(review): assumes batch always carries "gold_labels" when the
        # labeler is set — confirm this also holds for task-0 batches.
        if self.mention_labeler is not None:
            entity_logits = self.forward_labeler(emb)
            gold_preds, gold_targets = [], []

            for b, labels_dict in enumerate(batch["gold_labels"]):
                for (s, e), label_str in labels_dict.items():
                    if s < entity_logits.size(1) and e < entity_logits.size(2):
                        gold_preds.append(torch.argmax(entity_logits[b, s, e], dim=-1))
                        gold_targets.append(self.label2id[label_str])

            if gold_targets:
                self.val_f1_entity_labels.update(
                    torch.stack(gold_preds),
                    torch.tensor(gold_targets, device=emb.device)
                )

            log_stats["val_f1_entity_labels"] = self.val_f1_entity_labels

        loss_start = self._compute_start_loss(start_logits, batch)
        loss_end = self._compute_end_loss(end_logits, batch)
        log_stats["val_loss"] = loss_start + loss_end

        self.log_dict(log_stats, prog_bar=True, on_epoch=True, batch_size=len(batch["sentences"]))

    @torch.no_grad()
    def predict_mentions(
        self, sentences: list[list[str]], batch_size: int = 2
    ) -> list[list[tuple[int, int]]]:
        """Predict mention spans for each sentence.

        Args:
            sentences: pre-split sentences (lists of words).
            batch_size: number of sentences encoded per forward pass.
        Returns:
            One list of (start, end) word-index pairs per input sentence
            (end inclusive, end >= start).
        """
        self.eval()
        all_results = []
        thresh = self.hparams.threshold
        for i in range(0, len(sentences), batch_size):
            batch_sentences = sentences[i:i + batch_size]
            emb = self.encode(batch_sentences)
            start_logits, end_logits = self.forward_detector(emb)
            is_start = torch.sigmoid(start_logits) > thresh
            is_span = torch.sigmoid(end_logits) > thresh

            # Keep only pairs with end >= start, and require both the start
            # decision and the pair decision to fire.
            N = end_logits.size(1)
            upper_tri = torch.triu(
                torch.ones((N, N), device=self.device), diagonal=0
            ).bool()
            pred_mask = is_start.unsqueeze(2) & is_span & upper_tri

            # (num_predictions, 3) rows of (batch_idx, start_idx, end_idx).
            indices = pred_mask.nonzero()

            batch_results = [[] for _ in range(len(batch_sentences))]
            for b_idx, s_idx, e_idx in indices:
                batch_results[b_idx.item()].append((s_idx.item(), e_idx.item()))

            all_results.extend(batch_results)

        return all_results

    def test_step(self, batch, batch_idx):
        """Test uses the same metrics/loss computation as validation."""
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """AdamW over all parameters; the frozen encoder's parameters have
        requires_grad=False and therefore receive no updates."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
| |
|
| |
|
@ModelRegistry.register("model_v1")
def make_model_v1(data: DataBlob, model_name="distilroberta-base"):
    """Build a mention-detection-only model (no entity-labeling head).

    Args:
        data: unused here; kept so all registered factories share a signature.
        model_name: HuggingFace checkpoint name for the encoder backbone.
    Returns:
        LitMentionDetector with start/end detectors but no mention labeler.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    encoder = SentenceEncoder(model_name).train()
    # Read the hidden size from the encoder instead of hard-coding 768 so the
    # factory works with checkpoints of any width. (Also drops the redundant
    # second .train() call the old code made.)
    dim = encoder.dim
    start_detector = Detector(dim, dim)
    end_detector = Detector(dim * 2, dim)
    mention_detector = MentionDetectorCore(start_detector, end_detector)
    return LitMentionDetector(tokenizer, encoder, mention_detector)
| |
|
| |
|
@ModelRegistry.register("model_v2")
def make_model_v2(data: DataBlob, model_name="distilroberta-base"):
    """Build a joint mention-detection + entity-labeling model.

    Args:
        data: provides ``label2id``, the entity label -> class-index mapping.
        model_name: HuggingFace checkpoint name for the encoder backbone.
    Returns:
        LitMentionDetector with detectors plus a MentionLabeler head sized to
        ``len(data.label2id)`` classes.
    """
    label2id = data.label2id
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    encoder = SentenceEncoder(model_name).train()
    # Read the hidden size from the encoder instead of hard-coding 768 so the
    # factory works with checkpoints of any width. (Also drops the redundant
    # second .train() call the old code made.)
    dim = encoder.dim
    start_detector = Detector(dim, dim)
    end_detector = Detector(dim * 2, dim)
    classifier = Detector(dim * 2, dim, num_classes=len(label2id))
    mention_detector = MentionDetectorCore(start_detector, end_detector)
    mention_labeler = MentionLabeler(classifier)
    return LitMentionDetector(
        tokenizer,
        encoder,
        mention_detector,
        mention_labeler,
        label2id,
    )
| |
|