Modify: requirements.txt
Browse files- lm_scorer/__init__.py +0 -0
- lm_scorer/bin/__init__.py +0 -0
- lm_scorer/bin/cli.py +172 -0
- lm_scorer/models/__init__.py +0 -0
- lm_scorer/models/abc/__init__.py +0 -0
- lm_scorer/models/abc/base.py +103 -0
- lm_scorer/models/abc/batch.py +35 -0
- lm_scorer/models/abc/transformers.py +16 -0
- lm_scorer/models/auto.py +34 -0
- lm_scorer/models/gpt2.py +85 -0
- requirements.txt +0 -1
lm_scorer/__init__.py
ADDED
|
File without changes
|
lm_scorer/bin/__init__.py
ADDED
|
File without changes
|
lm_scorer/bin/cli.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import itertools
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from ..models.auto import AutoLMScorer as LMScorer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_args() -> argparse.Namespace:
|
| 16 |
+
parser = argparse.ArgumentParser(
|
| 17 |
+
description="Get sentences probability using a language model.",
|
| 18 |
+
)
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"sentences_file_path",
|
| 21 |
+
metavar="sentences-file-path",
|
| 22 |
+
type=str,
|
| 23 |
+
help="A file containing sentences to score, one per line."
|
| 24 |
+
" If - is given as filename it reads from stdin instead.",
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--model-name",
|
| 28 |
+
"-m",
|
| 29 |
+
type=str,
|
| 30 |
+
default="gpt2",
|
| 31 |
+
help="The pretrained language model to use. Can be one of: %s."
|
| 32 |
+
% ", ".join(LMScorer.supported_model_names()),
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--tokens",
|
| 36 |
+
"-t",
|
| 37 |
+
action="store_true",
|
| 38 |
+
help="If provided it provides the probability of each token of each sentence.",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--log-prob",
|
| 42 |
+
"-lp",
|
| 43 |
+
action="store_true",
|
| 44 |
+
help="If provided log probabilities are returned instead.",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--reduce",
|
| 48 |
+
"-r",
|
| 49 |
+
type=str,
|
| 50 |
+
default="prod",
|
| 51 |
+
help="Reduce strategy applied on token probabilities to get the sentence score."
|
| 52 |
+
" Available strategies are: prod, mean, gmean, hmean.",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--batch-size",
|
| 56 |
+
"-b",
|
| 57 |
+
type=int,
|
| 58 |
+
default=1,
|
| 59 |
+
help="Number of sentences to process in parallel.",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--significant-figures",
|
| 63 |
+
"-sf",
|
| 64 |
+
type=int,
|
| 65 |
+
default=5,
|
| 66 |
+
help="Number of significant figures to use when printing numbers.",
|
| 67 |
+
)
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--cuda",
|
| 70 |
+
type=int,
|
| 71 |
+
default=-1,
|
| 72 |
+
help="If provided it runs the model on the given cuda device.",
|
| 73 |
+
)
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--debug",
|
| 76 |
+
action="store_true",
|
| 77 |
+
help="If provided it provides additional logging in case of errors.",
|
| 78 |
+
)
|
| 79 |
+
return parser.parse_args()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def normalize_args(args: argparse.Namespace) -> None:
|
| 83 |
+
if args.sentences_file_path != "-":
|
| 84 |
+
args.sentences_file_path = os.path.realpath(args.sentences_file_path)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def validate_args(args: argparse.Namespace) -> None:
|
| 88 |
+
if args.sentences_file_path != "-":
|
| 89 |
+
if not os.path.isfile(args.sentences_file_path):
|
| 90 |
+
raise ValueError("The provided sentences file path is invalid.")
|
| 91 |
+
|
| 92 |
+
if args.cuda >= 0 and not torch.cuda.is_available():
|
| 93 |
+
raise ValueError("No Cuda device found.")
|
| 94 |
+
|
| 95 |
+
if args.cuda >= torch.cuda.device_count():
|
| 96 |
+
device_count = torch.cuda.device_count()
|
| 97 |
+
raise ValueError("Invalid Cuda device: %d/%d." % (args.cuda, device_count))
|
| 98 |
+
|
| 99 |
+
if args.batch_size <= 0:
|
| 100 |
+
raise ValueError("The batch size must be positive.")
|
| 101 |
+
|
| 102 |
+
if args.significant_figures <= 0:
|
| 103 |
+
raise ValueError("The number of significant figures must be positive.")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
T1 = TypeVar("T1") # pylint: disable=invalid-name
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def grouper(iterable: Iterable[T1], size: int) -> Generator[List[T1], None, None]:
|
| 110 |
+
it = iter(iterable) # pylint: disable=invalid-name
|
| 111 |
+
while True:
|
| 112 |
+
chunk = list(itertools.islice(it, size))
|
| 113 |
+
if not chunk:
|
| 114 |
+
return
|
| 115 |
+
yield chunk
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def main(args: argparse.Namespace) -> None:
|
| 119 |
+
# pylint: disable=too-many-locals
|
| 120 |
+
if args.sentences_file_path == "-":
|
| 121 |
+
sentences_stream = sys.stdin
|
| 122 |
+
else:
|
| 123 |
+
sentences_stream = open(args.sentences_file_path, "r")
|
| 124 |
+
|
| 125 |
+
sig_fig = args.significant_figures
|
| 126 |
+
batch_size = args.batch_size
|
| 127 |
+
device = torch.device("cuda:%d" % args.cuda if args.cuda >= 0 else "cpu")
|
| 128 |
+
scorer = LMScorer.from_pretrained(
|
| 129 |
+
args.model_name, device=device, batch_size=batch_size
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
buffer_size = args.batch_size * 2
|
| 133 |
+
for sentences in grouper(sentences_stream, buffer_size):
|
| 134 |
+
sentences = [sentence.strip() for sentence in sentences]
|
| 135 |
+
|
| 136 |
+
sent_scores = scorer.sentence_score(
|
| 137 |
+
sentences, log=args.log_prob, reduce=args.reduce
|
| 138 |
+
)
|
| 139 |
+
if args.tokens:
|
| 140 |
+
sent_info = scorer.tokens_score(sentences, log=args.log_prob)
|
| 141 |
+
|
| 142 |
+
sent_num = len(sentences)
|
| 143 |
+
for i in range(sent_num):
|
| 144 |
+
sentence, sent_score = sentences[i], sent_scores[i]
|
| 145 |
+
print(f"%s\t%.{sig_fig}g" % (sentence, sent_score))
|
| 146 |
+
if args.tokens:
|
| 147 |
+
scores, _, tokens = sent_info[i]
|
| 148 |
+
for score, token in zip(scores, tokens):
|
| 149 |
+
print(f"%s\t%.{sig_fig}g" % (token, score))
|
| 150 |
+
print("")
|
| 151 |
+
|
| 152 |
+
if args.sentences_file_path != "-":
|
| 153 |
+
sentences_stream.close()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run() -> None:
|
| 157 |
+
try:
|
| 158 |
+
args = parse_args()
|
| 159 |
+
|
| 160 |
+
normalize_args(args)
|
| 161 |
+
validate_args(args)
|
| 162 |
+
main(args)
|
| 163 |
+
except KeyboardInterrupt:
|
| 164 |
+
print("\nAborted!")
|
| 165 |
+
except Exception as err: # pylint: disable=broad-except
|
| 166 |
+
if args.debug:
|
| 167 |
+
raise
|
| 168 |
+
print("Error: %s" % err)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
if __name__ == "__main__":
|
| 172 |
+
run()
|
lm_scorer/models/__init__.py
ADDED
|
File without changes
|
lm_scorer/models/abc/__init__.py
ADDED
|
File without changes
|
lm_scorer/models/abc/base.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LMScorer(ABC):
|
| 10 |
+
def __init__(self, model_name: str, **kwargs: Any) -> None:
|
| 11 |
+
self._build(model_name, kwargs)
|
| 12 |
+
|
| 13 |
+
@overload
|
| 14 |
+
def sentence_score(
|
| 15 |
+
self, text: str, log: bool = False, reduce: str = "prod"
|
| 16 |
+
) -> float:
|
| 17 |
+
...
|
| 18 |
+
|
| 19 |
+
@overload
|
| 20 |
+
def sentence_score(
|
| 21 |
+
self, text: List[str], log: bool = False, reduce: str = "prod"
|
| 22 |
+
) -> List[float]:
|
| 23 |
+
...
|
| 24 |
+
|
| 25 |
+
def sentence_score(
|
| 26 |
+
self, text: Union[str, List[str]], log: bool = False, reduce: str = "prod",
|
| 27 |
+
) -> Union[float, List[float]]:
|
| 28 |
+
sentences = [text] if isinstance(text, str) else text
|
| 29 |
+
scores: List[float] = []
|
| 30 |
+
if len(sentences) == 0:
|
| 31 |
+
return scores
|
| 32 |
+
|
| 33 |
+
outputs = self._tokens_log_prob(sentences)
|
| 34 |
+
for output in outputs:
|
| 35 |
+
log_probs = output[0]
|
| 36 |
+
tlen = log_probs.shape[0]
|
| 37 |
+
|
| 38 |
+
if reduce == "prod":
|
| 39 |
+
score = log_probs.sum()
|
| 40 |
+
elif reduce == "mean":
|
| 41 |
+
score = log_probs.logsumexp(0) - math.log(tlen)
|
| 42 |
+
elif reduce == "gmean":
|
| 43 |
+
score = log_probs.mean(0)
|
| 44 |
+
elif reduce == "hmean":
|
| 45 |
+
score = log_probs.neg().logsumexp(0).neg() + math.log(tlen)
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError("Unrecognized scoring strategy: %s" % reduce)
|
| 48 |
+
if not log:
|
| 49 |
+
score = score.exp()
|
| 50 |
+
|
| 51 |
+
scores.append(score.item())
|
| 52 |
+
|
| 53 |
+
return scores[0] if isinstance(text, str) else scores
|
| 54 |
+
|
| 55 |
+
@overload
|
| 56 |
+
def tokens_score(
|
| 57 |
+
self, text: str, log: bool = False
|
| 58 |
+
) -> Tuple[List[float], List[int], List[str]]:
|
| 59 |
+
...
|
| 60 |
+
|
| 61 |
+
@overload
|
| 62 |
+
def tokens_score(
|
| 63 |
+
self, text: List[str], log: bool = False
|
| 64 |
+
) -> List[Tuple[List[float], List[int], List[str]]]:
|
| 65 |
+
...
|
| 66 |
+
|
| 67 |
+
def tokens_score(
|
| 68 |
+
self, text: Union[str, List[str]], log: bool = False
|
| 69 |
+
) -> Union[
|
| 70 |
+
Tuple[List[float], List[int], List[str]],
|
| 71 |
+
List[Tuple[List[float], List[int], List[str]]],
|
| 72 |
+
]:
|
| 73 |
+
sentences = [text] if isinstance(text, str) else text
|
| 74 |
+
outputs: List[Tuple[List[float], List[int], List[str]]] = []
|
| 75 |
+
if len(sentences) == 0:
|
| 76 |
+
return outputs
|
| 77 |
+
|
| 78 |
+
for log_probs, ids, tokens in self._tokens_log_prob(sentences):
|
| 79 |
+
scores = log_probs if log else log_probs.exp()
|
| 80 |
+
scores = cast(torch.DoubleTensor, scores)
|
| 81 |
+
output = (scores.tolist(), ids.tolist(), tokens)
|
| 82 |
+
outputs.append(output)
|
| 83 |
+
|
| 84 |
+
return outputs[0] if isinstance(text, str) else outputs
|
| 85 |
+
|
| 86 |
+
@classmethod
|
| 87 |
+
def supported_model_names(cls) -> Iterable[str]:
|
| 88 |
+
return cls._supported_model_names()
|
| 89 |
+
|
| 90 |
+
def _build(self, model_name: str, options: Dict[str, Any]) -> None:
|
| 91 |
+
# pylint: disable=attribute-defined-outside-init, unused-argument
|
| 92 |
+
self.model_name = model_name
|
| 93 |
+
|
| 94 |
+
@abstractmethod
|
| 95 |
+
def _tokens_log_prob(
|
| 96 |
+
self, text: List[str]
|
| 97 |
+
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
|
| 98 |
+
... # pragma: no cover
|
| 99 |
+
|
| 100 |
+
@classmethod
|
| 101 |
+
@abstractmethod
|
| 102 |
+
def _supported_model_names(cls) -> Iterable[str]:
|
| 103 |
+
... # pragma: no cover
|
lm_scorer/models/abc/batch.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=abstract-method
|
| 2 |
+
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
|
| 3 |
+
from abc import abstractmethod
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .base import LMScorer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BatchedLMScorer(LMScorer):
|
| 11 |
+
# @overrides
|
| 12 |
+
def _build(self, model_name: str, options: Dict[str, Any]) -> None:
|
| 13 |
+
super()._build(model_name, options)
|
| 14 |
+
|
| 15 |
+
batch_size = options.get("batch_size", 1)
|
| 16 |
+
if batch_size < 1:
|
| 17 |
+
raise ValueError("The batch_size option must be positive")
|
| 18 |
+
# pylint: disable=attribute-defined-outside-init
|
| 19 |
+
self.batch_size = batch_size
|
| 20 |
+
|
| 21 |
+
# @overrides
|
| 22 |
+
def _tokens_log_prob(
|
| 23 |
+
self, text: List[str]
|
| 24 |
+
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
|
| 25 |
+
outputs = []
|
| 26 |
+
for i in range(0, len(text), self.batch_size):
|
| 27 |
+
batch = text[i : i + self.batch_size]
|
| 28 |
+
outputs.extend(self._tokens_log_prob_for_batch(batch))
|
| 29 |
+
return outputs
|
| 30 |
+
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def _tokens_log_prob_for_batch(
|
| 33 |
+
self, text: List[str]
|
| 34 |
+
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
|
| 35 |
+
... # pragma: no cover
|
lm_scorer/models/abc/transformers.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=abstract-method
|
| 2 |
+
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from .batch import BatchedLMScorer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TransformersLMScorer(BatchedLMScorer):
|
| 10 |
+
# @overrides
|
| 11 |
+
def _build(self, model_name: str, options: Dict[str, Any]) -> None:
|
| 12 |
+
super()._build(model_name, options)
|
| 13 |
+
|
| 14 |
+
# Make transformers cache path configurable.
|
| 15 |
+
cache_dir = os.environ.get("TRANSFORMERS_CACHE_DIR", ".transformers_cache")
|
| 16 |
+
options["cache_dir"] = options.get("cache_dir", cache_dir)
|
lm_scorer/models/auto.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
|
| 5 |
+
from .abc.base import LMScorer
|
| 6 |
+
from .gpt2 import GPT2LMScorer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AutoLMScorer:
|
| 10 |
+
MODEL_CLASSES = [GPT2LMScorer]
|
| 11 |
+
|
| 12 |
+
def __init__(self):
|
| 13 |
+
raise EnvironmentError(
|
| 14 |
+
"AutoLMscorer is designed to be instantiated "
|
| 15 |
+
"using the `AutoLMscorer.from_pretrained(model_name)`"
|
| 16 |
+
"method"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
@classmethod
|
| 20 |
+
def from_pretrained(cls, model_name: str, **kwargs: Any) -> LMScorer:
|
| 21 |
+
for model_class in cls.MODEL_CLASSES:
|
| 22 |
+
if model_name not in model_class.supported_model_names():
|
| 23 |
+
continue
|
| 24 |
+
return model_class(model_name, **kwargs)
|
| 25 |
+
raise ValueError(
|
| 26 |
+
"Unrecognized model name."
|
| 27 |
+
"Can be one of: %s" % ", ".join(cls.supported_model_names()),
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def supported_model_names(cls) -> Iterable[str]:
|
| 32 |
+
classes = cls.MODEL_CLASSES
|
| 33 |
+
models = map(lambda c: c.supported_model_names(), classes)
|
| 34 |
+
return itertools.chain.from_iterable(models)
|
lm_scorer/models/gpt2.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoTokenizer, GPT2LMHeadModel
|
| 6 |
+
from transformers import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
| 7 |
+
from transformers.tokenization_utils import BatchEncoding
|
| 8 |
+
|
| 9 |
+
from .abc.transformers import TransformersLMScorer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GPT2LMScorer(TransformersLMScorer):
|
| 13 |
+
# @overrides
|
| 14 |
+
def _build(self, model_name: str, options: Dict[str, Any]) -> None:
|
| 15 |
+
super()._build(model_name, options)
|
| 16 |
+
|
| 17 |
+
# pylint: disable=attribute-defined-outside-init
|
| 18 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 19 |
+
model_name, use_fast=True, add_special_tokens=False
|
| 20 |
+
)
|
| 21 |
+
# Add the pad token to GPT2 dictionary.
|
| 22 |
+
# len(tokenizer) = vocab_size + 1
|
| 23 |
+
self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]})
|
| 24 |
+
self.tokenizer.pad_token = "<|pad|>"
|
| 25 |
+
|
| 26 |
+
self.model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 27 |
+
# We need to resize the embedding layer because we added the pad token.
|
| 28 |
+
self.model.resize_token_embeddings(len(self.tokenizer))
|
| 29 |
+
self.model.eval()
|
| 30 |
+
if "device" in options:
|
| 31 |
+
self.model.to(options["device"])
|
| 32 |
+
|
| 33 |
+
def _add_special_tokens(self, text: str) -> str:
|
| 34 |
+
return self.tokenizer.bos_token + text + self.tokenizer.eos_token
|
| 35 |
+
|
| 36 |
+
# @overrides
|
| 37 |
+
def _tokens_log_prob_for_batch(
|
| 38 |
+
self, text: List[str]
|
| 39 |
+
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
|
| 40 |
+
outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = []
|
| 41 |
+
if len(text) == 0:
|
| 42 |
+
return outputs
|
| 43 |
+
|
| 44 |
+
# TODO: Handle overflowing elements for long sentences
|
| 45 |
+
text = list(map(self._add_special_tokens, text))
|
| 46 |
+
encoding: BatchEncoding = self.tokenizer.batch_encode_plus(
|
| 47 |
+
text, return_tensors="pt",
|
| 48 |
+
)
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
ids = encoding["input_ids"].to(self.model.device)
|
| 51 |
+
attention_mask = encoding["attention_mask"].to(self.model.device)
|
| 52 |
+
nopad_mask = ids != self.tokenizer.pad_token_id
|
| 53 |
+
logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0]
|
| 54 |
+
|
| 55 |
+
for sent_index in range(len(text)):
|
| 56 |
+
sent_nopad_mask = nopad_mask[sent_index]
|
| 57 |
+
# len(tokens) = len(text[sent_index]) + 1
|
| 58 |
+
sent_tokens = [
|
| 59 |
+
tok
|
| 60 |
+
for i, tok in enumerate(encoding.tokens(sent_index))
|
| 61 |
+
if sent_nopad_mask[i] and i != 0
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
# sent_ids.shape = [len(text[sent_index]) + 1]
|
| 65 |
+
sent_ids = ids[sent_index, sent_nopad_mask][1:]
|
| 66 |
+
# logits.shape = [len(text[sent_index]) + 1, vocab_size]
|
| 67 |
+
sent_logits = logits[sent_index, sent_nopad_mask][:-1, :]
|
| 68 |
+
sent_logits[:, self.tokenizer.pad_token_id] = float("-inf")
|
| 69 |
+
# ids_scores.shape = [seq_len + 1]
|
| 70 |
+
sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1)
|
| 71 |
+
# log_prob.shape = [seq_len + 1]
|
| 72 |
+
sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1)
|
| 73 |
+
|
| 74 |
+
sent_log_probs = cast(torch.DoubleTensor, sent_log_probs)
|
| 75 |
+
sent_ids = cast(torch.LongTensor, sent_ids)
|
| 76 |
+
|
| 77 |
+
output = (sent_log_probs, sent_ids, sent_tokens)
|
| 78 |
+
outputs.append(output)
|
| 79 |
+
|
| 80 |
+
return outputs
|
| 81 |
+
|
| 82 |
+
# @overrides
|
| 83 |
+
@classmethod
|
| 84 |
+
def _supported_model_names(cls) -> Iterable[str]:
|
| 85 |
+
return GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()
|
requirements.txt
CHANGED
|
@@ -6,5 +6,4 @@ python-Levenshtein==0.12.2
|
|
| 6 |
fuzzywuzzy==0.18.0
|
| 7 |
tokenizers==0.10.2
|
| 8 |
fsspec==2021.5.0
|
| 9 |
-
lm-scorer==0.4.2 --install-option='--ignore-requires-python'
|
| 10 |
errant
|
|
|
|
| 6 |
fuzzywuzzy==0.18.0
|
| 7 |
tokenizers==0.10.2
|
| 8 |
fsspec==2021.5.0
|
|
|
|
| 9 |
errant
|