# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional

import torch
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from transformers import AutoTokenizer

from nemo.collections.common.losses import MultiSimilarityLoss
from nemo.collections.nlp.data import EntityLinkingDataset
from nemo.collections.nlp.models.nlp_model import NLPModel
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.neural_types import LogitsType, NeuralType
from nemo.utils import logging

__all__ = ['EntityLinkingModel']


class EntityLinkingModel(NLPModel, Exportable):
"""
Second stage pretraining of BERT based language model
for entity linking task. An implementation of Liu et. al's
NAACL 2021 paper Self-Alignment Pretraining for Biomedical Entity Representations.
"""

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        # Output is a batch of fixed-size embedding vectors: ('B', 'D') = (batch, embedding dim).
        return {"logits": NeuralType(('B', 'D'), LogitsType())}

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the SAP-BERT model for entity linking."""

        # tokenizer needed before super().__init__() so dataset and loader can process data
        self._setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        # Token to use for the self-alignment loss, typically the first token, [CLS]
        self._idx_conditioned_on = 0
        self.loss = MultiSimilarityLoss()

    def _setup_tokenizer(self, cfg: DictConfig):
        """Builds a HuggingFace AutoTokenizer from the tokenizer config
        (expects `tokenizer_name`, `vocab_file`, and `do_lower_case` fields)."""
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.tokenizer_name, vocab_file=cfg.vocab_file, do_lower_case=cfg.do_lower_case
        )
        self.tokenizer = tokenizer

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        hidden_states = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]

        # normalize to unit sphere
        logits = torch.nn.functional.normalize(hidden_states[:, self._idx_conditioned_on], p=2, dim=1)
        return logits
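
    # Because forward() L2-normalizes the [CLS] embeddings, cosine similarity between a
    # query mention and a set of candidate concept embeddings reduces to a plain dot product.
    # A minimal retrieval sketch (tensor names below are hypothetical, not part of this file):
    #
    #     query_emb = model(input_ids=q_ids, token_type_ids=q_type_ids, attention_mask=q_mask)  # (1, D)
    #     candidate_embs = ...  # (N, D) precomputed, normalized concept embeddings
    #     scores = query_emb @ candidate_embs.T  # (1, N) cosine similarities
    #     best_match = scores.argmax(dim=1)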

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        input_ids, token_type_ids, attention_mask, concept_ids = batch
        logits = self.forward(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        train_loss = self.loss(logits=logits, labels=concept_ids)

        # No hard examples found in batch,
        # shouldn't use this batch to update model weights
        if train_loss == 0:
            train_loss = None
            lr = None
        else:
            lr = self._optimizer.param_groups[0]["lr"]
            self.log("train_loss", train_loss)
            self.log("lr", lr, prog_bar=True)

        return {"loss": train_loss, "lr": lr}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, concept_ids = batch
        with torch.no_grad():
            logits = self.forward(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
            val_loss = self.loss(logits=logits, labels=concept_ids)

        # No hard examples found in batch,
        # val loss not used to update model weights
        if val_loss == 0:
            val_loss = None
        else:
            self.log("val_loss", val_loss)
            logging.info(f"val loss: {val_loss}")

        return {"val_loss": val_loss}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.

        Args:
            outputs: list of individual outputs of each validation step.
        """
        if outputs:
            # Batches with no hard examples return val_loss=None; average only the real losses.
            losses = [x["val_loss"] for x in outputs if x["val_loss"] is not None]
            if losses:
                avg_loss = torch.stack(losses).mean()
                self.log("val_loss", avg_loss, prog_bar=True)
                return {"val_loss": avg_loss}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        if not train_data_config or not train_data_config.data_file:
            logging.info(
                "Dataloader config or file_path or processed data path for the train dataset is missing, "
                "so no data loader for train is created!"
            )
            self._train_dl = None
            return

        self._train_dl = self.setup_dataloader(cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        if not val_data_config or not val_data_config.data_file:
            logging.info(
                "Dataloader config or file_path or processed data path for the val dataset is missing, "
                "so no data loader for validation is created!"
            )
            self._validation_dl = None
            return

        self._validation_dl = self.setup_dataloader(cfg=val_data_config)

    def setup_dataloader(self, cfg: Dict, is_index_data: bool = False) -> 'torch.utils.data.DataLoader':
        dataset = EntityLinkingDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.data_file,
            max_seq_length=cfg.max_seq_length,
            is_index_data=is_index_data,
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=cfg.get("shuffle", True),
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
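
    # The dataloader config consumed above only needs a handful of fields; a hypothetical
    # dataset section (YAML) might look like the following -- values are illustrative,
    # not NeMo's shipped defaults:
    #
    #     train_ds:
    #       data_file: /path/to/train_pairs.tsv
    #       max_seq_length: 128
    #       batch_size: 128
    #       shuffle: true
    #       num_workers: 2
    #       pin_memory: false
    #       drop_last: false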

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass

    @classmethod
    def from_pretrained(cls, name: str):
        pass
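

# A minimal usage sketch (the config path and structure below are hypothetical, shown only
# to illustrate how the model is typically driven from a Hydra/OmegaConf config):
#
#     from omegaconf import OmegaConf
#     from pytorch_lightning import Trainer
#
#     cfg = OmegaConf.load("conf/entity_linking_config.yaml")  # hypothetical path
#     trainer = Trainer(**cfg.trainer)
#     model = EntityLinkingModel(cfg=cfg.model, trainer=trainer)
#     trainer.fit(model)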