|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from typing import Dict, Any, List |
|
|
from scipy.special import softmax |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
# Run on GPU when one is present; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
class EndpointHandler(): |
|
|
def __init__(self, path="."):
    """Load the tokenizer and causal-LM weights from *path*, moving the model to the active device."""
    self.tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)
    self.model = model.to(device)
|
|
|
|
|
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Score each input token against the model's own next-token predictions.

    data args:
        inputs (:obj:`str`): text to analyse; if the key is absent,
            ``data`` itself is passed to the tokenizer.

    Return:
        A :obj:`list` of per-token dicts (will be serialized and returned),
        one per input token (excluding the first), with keys:
        ``input`` (decoded token), ``rank`` (rank of the actual token in the
        model's sorted predictions), ``prob`` (probability mass of the tokens
        ranked above it), ``most_likely`` (decoded top prediction) and
        ``position`` (token index).
    """
    input_text = data.pop("inputs", data)
    input_ids = self.tokenizer(input_text, return_tensors="pt").to(device)

    # Inference only: no_grad avoids building an autograd graph for
    # every request, which would waste memory and time.
    with torch.no_grad():
        model_output = self.model(**input_ids)

    # Align predictions with input tokens.
    # NOTE(review): this assumes logits[offset:] lines up with
    # input_ids[1:] — confirm against _best_offset's contract.
    offset = self._best_offset(input_ids['input_ids'], model_output)
    self.logits = model_output.logits[0][offset:]
    self.inputs = input_ids['input_ids'][0].cpu().numpy()[1:]

    # Sort each position's logits descending so index 0 is the most
    # likely token; keep numpy copies for the scoring below.
    # (Renamed from `sorted`/`indicies`: the former shadowed the builtin.)
    sorted_logits, indices = self.logits.sort(descending=True)
    indices = indices.cpu().numpy()
    self.sorted = sorted_logits.cpu().numpy()

    def parse_tokens(idx):
        # Rank of the actual token among the model's sorted predictions.
        # Cast to native int/float: numpy scalars are not JSON-serializable.
        token_rank = int(np.where(indices[idx] == self.inputs[idx])[0][0])
        # Total probability mass the model placed on tokens it ranked higher.
        upper_prob = float(np.sum(softmax(self.sorted[idx])[:token_rank]))
        return {
            "input": self.tokenizer.decode(self.inputs[idx]),
            "rank": token_rank,
            "prob": upper_prob,
            "most_likely": self.tokenizer.decode(int(self.logits[idx].argmax())),
            "position": idx,
        }

    return [parse_tokens(idx) for idx in range(len(self.inputs))]
|
|
|
|
|
@staticmethod |
|
|
def _best_offset(inputs, outputs): |
|
|
"""Calculates overlap between input and output tokens""" |
|
|
MAX_OFFSET = 10 |
|
|
|
|
|
|
|
|
top_outputs = outputs.logits[0].argmax(dim=-1).cpu().numpy() |
|
|
|
|
|
|
|
|
matches = np.zeros((len(inputs), len(top_outputs))) |
|
|
for i, input in enumerate(inputs[:MAX_OFFSET]): |
|
|
for j, output in enumerate(top_outputs[:i]): |
|
|
if input == output: |
|
|
matches[j, i] = 1 |