from typing import List, Dict

import torch
from torch import nn

try:
    from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
    from ditk import logging
    logging.warning("transformers is not found, please install it using: pip install transformers")

from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):
    """
    Overview:
        The LanguageTransformer network. Download a pre-trained language model and add a head on top of it.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
            add_linear: bool = False,
            embedding_size: int = 128,
            freeze_encoder: bool = True
    ) -> None:
        """
        Overview:
            Initialize the LanguageTransformer model according to the input arguments.
        Arguments:
            - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
            - add_linear (:obj:`bool`): Whether to add a linear layer on the top of the language model, defaults to \
                ``False``.
            - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
            - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
                defaults to ``True``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)

        # Freeze the transformer encoder and only train the added linear layer.
        if freeze_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

        if add_linear:
            # Add a small, adjustable linear layer on top of the language model, tuned through RL.
            self.embedding_size = embedding_size
            self.linear = nn.Linear(
                self.model.config.hidden_size, embedding_size
            )  # hidden_size is 768 for bert-base-uncased and distilbert-base-uncased
        else:
            self.linear = None

    def _calc_embedding(self, x: list) -> torch.Tensor:
        # ``truncation=True`` means that if the length of a prompt exceeds the ``max_length`` of the tokenizer,
        # the exceeded part will be truncated. ``padding=True`` means that shorter prompts are padded to the length
        # of the longest prompt in the batch. These settings ensure all encoded prompts in a batch share the same
        # length (at most ``max_length``), which enables batch-wise computing.
        input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
        output = self.model(**input, output_hidden_states=True)
        # Get the hidden states of the last layer.
        last_hidden_states = output.hidden_states[-1]
        # Take the hidden state of the [CLS] token as the sentence embedding.
        sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
        if self.linear is not None:
            sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size
        return sentence_embedding

    def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
        """
        Overview:
            LanguageTransformer forward computation graph, input two lists of strings and predict their \
            matching scores.
        Arguments:
            - train_samples (:obj:`List[str]`): One list of strings.
            - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
        Returns:
            - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
                corresponding ``torch.distributions.Categorical`` object.
        Examples:
            >>> test_pids = [1]
            >>> cand_pids = [0, 2, 4]
            >>> problems = [ \
                "This is problem 0", "This is the first question", "Second problem is here", "Another problem", \
                "This is the last problem" \
            ]
            >>> ctxt_list = [problems[pid] for pid in test_pids]
            >>> cands_list = [problems[pid] for pid in cand_pids]
            >>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
            >>> output = model(ctxt_list, cands_list)
            >>> assert output['logit'].shape == (1, 3)
        """
        prompt_embedding = self._calc_embedding(train_samples)
        cands_embedding = self._calc_embedding(candidate_samples)
        # The matching score of each (prompt, candidate) pair is the dot product of their sentence embeddings.
        scores = torch.mm(prompt_embedding, cands_embedding.t())
        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
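

# A minimal usage sketch (illustrative addition, not part of the original module). It assumes ``transformers``
# is installed and that "bert-base-uncased" can be downloaded from the HuggingFace hub.
if __name__ == "__main__":
    # Build the model with a small trainable projection head on top of the frozen encoder.
    model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
    problems = [
        "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
        "This is the last problem"
    ]
    ctxt_list = [problems[pid] for pid in [1]]  # one query prompt
    cands_list = [problems[pid] for pid in [0, 2, 4]]  # three candidate prompts
    output = model(ctxt_list, cands_list)
    # ``logit`` holds the raw matching scores; ``dist`` wraps them as a Categorical distribution over candidates.
    assert output['logit'].shape == (1, 3)
    print(output['logit'])
    print(output['dist'].probs)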