# LogitLoader / handler.py
# Uploaded by ColeD0 in commit 7a24a47 ("Upload 3 files").
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any, List
from scipy.special import softmax
import numpy as np
import torch
# Target device for the model and its inputs: the default CUDA device when
# torch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class EndpointHandler():
def __init__(self, path="."):
    """Load the tokenizer and causal language model stored under ``path``.

    The model is moved to the module-level ``device`` once loaded.
    """
    self.tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)
    self.model = model.to(device)
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Score each input token against the model's next-token predictions.

    data args:
        inputs (:obj:`str`): text to tokenize and score. If the ``"inputs"``
            key is absent, ``data`` itself is handed to the tokenizer.
    Return:
        A :obj:`list` of :obj:`dict`, one entry per input token after the
        first, with keys:
            ``input``       decoded text of the input token,
            ``rank``        0-based position of that token in the descending
                            logit ordering of the corresponding row
                            (0 = the model's top prediction),
            ``prob``        summed softmax probability of every token ranked
                            above it,
            ``most_likely`` decoded text of the highest-logit token,
            ``position``    index of the entry in the list.
        All values are plain Python ``str``/``int``/``float`` so the list can
        be JSON-serialized.

    Side effects: stores ``self.logits``, ``self.inputs`` and ``self.sorted``
    from the most recent request on the instance.
    """
    input_text = data.pop("inputs", data)
    encoded = self.tokenizer(input_text, return_tensors="pt").to(device)
    # Inference only. Without no_grad the forward pass records an autograd
    # graph, and holding ``self.logits`` would keep that graph (and all its
    # activations) alive until the next request.
    with torch.no_grad():
        model_output = self.model(**encoded)

    # Offset of the first logit row to keep; intended to strip the BOS
    # position in a model-agnostic way (computed by _best_offset).
    offset = self._best_offset(encoded['input_ids'], model_output)
    # Only the first sequence of the batch is scored.
    self.logits = model_output.logits[0][offset:]
    # Drop the first input token, so logit row i is paired with input
    # token i + 1 of the original sequence.
    self.inputs = encoded['input_ids'][0].cpu().numpy()[1:]

    sorted_logits, sorted_ids = self.logits.sort(descending=True)
    sorted_ids = sorted_ids.cpu().numpy()
    # NumPy has no bfloat16; upcast so `.numpy()` also works for
    # half-precision checkpoints (a no-op for float32).
    self.sorted = sorted_logits.float().cpu().numpy()
    # One vectorised argmax instead of a device round-trip per position.
    top_ids = self.logits.argmax(dim=-1).cpu().numpy()

    def parse_tokens(idx):
        """Build the result entry for position ``idx``."""
        target = int(self.inputs[idx])
        # Position of the actual token in this row's descending ordering.
        token_rank = int(np.where(sorted_ids[idx] == target)[0][0])
        # Probability mass the model assigned to tokens it preferred.
        upper_prob = float(np.sum(softmax(self.sorted[idx])[:token_rank]))
        return {
            "input": self.tokenizer.decode(target),
            "rank": token_rank,
            "prob": upper_prob,
            "most_likely": self.tokenizer.decode(int(top_ids[idx])),
            "position": idx}

    return [parse_tokens(idx) for idx in range(len(self.inputs))]
@staticmethod
def _best_offset(inputs, outputs):
"""Calculates overlap between input and output tokens"""
MAX_OFFSET = 10 # Tokens allowed to for offsetting
# Get tokens from output
top_outputs = outputs.logits[0].argmax(dim=-1).cpu().numpy()
# Generate match matrix
matches = np.zeros((len(inputs), len(top_outputs)))
for i, input in enumerate(inputs[:MAX_OFFSET]):
for j, output in enumerate(top_outputs[:i]):
if input == output:
matches[j, i] = 1