ColeD0 commited on
Commit
7a24a47
·
verified ·
1 Parent(s): 6963d77

Upload 3 files

Browse files
Files changed (3) hide show
  1. .gitattributes +0 -34
  2. handler.py +63 -0
  3. requirements.txt +3 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
handler.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ from typing import Dict, Any, List
3
+ from scipy.special import softmax
4
+ import numpy as np
5
+ import torch
6
+
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ class EndpointHandler():
10
+ def __init__(self, path="."):
11
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
12
+ self.model = AutoModelForCausalLM.from_pretrained(path).to(device)
13
+
14
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
+ """
16
+ data args:
17
+ inputs (:obj: `str`)
18
+ Return:
19
+ A :obj:`list` | `dict`: will be serialized and returned
20
+ """
21
+ # Get model output
22
+ input_text = data.pop("inputs", data)
23
+ input_ids = self.tokenizer(input_text, return_tensors="pt").to(device)
24
+ model_output = self.model(**input_ids)
25
+
26
+ # Get best offset (Strips out BOS token in model-agnostic way)
27
+ offset = self._best_offset(input_ids['input_ids'], model_output)
28
+ self.logits = model_output.logits[0][offset:]
29
+ self.inputs = input_ids['input_ids'][0].cpu().numpy()[1:]
30
+
31
+ # Prep logits
32
+ sorted, indicies = self.logits.sort(descending=True)
33
+ indicies = indicies.cpu().numpy()
34
+ self.sorted = sorted.cpu().detach().numpy()
35
+
36
+ # Initialize tokens
37
+ def parse_tokens(idx):
38
+ token_rank = np.where(indicies[idx] == self.inputs[idx])[0][0]
39
+ upper_prob = np.sum(softmax(self.sorted[idx])[:token_rank])
40
+ return {
41
+ "input": self.tokenizer.decode(self.inputs[idx]),
42
+ "rank": token_rank,
43
+ "prob": upper_prob,
44
+ "most_likely": self.tokenizer.decode(self.logits[idx].argmax()),
45
+ "position": idx}
46
+
47
+ tokens = [parse_tokens(idx) for idx in range(len(self.inputs))]
48
+ return tokens
49
+
50
+ @staticmethod
51
+ def _best_offset(inputs, outputs):
52
+ """Calculates overlap between input and output tokens"""
53
+ MAX_OFFSET = 10 # Tokens allowed to for offsetting
54
+
55
+ # Get tokens from output
56
+ top_outputs = outputs.logits[0].argmax(dim=-1).cpu().numpy()
57
+
58
+ # Generate match matrix
59
+ matches = np.zeros((len(inputs), len(top_outputs)))
60
+ for i, input in enumerate(inputs[:MAX_OFFSET]):
61
+ for j, output in enumerate(top_outputs[:i]):
62
+ if input == output:
63
+ matches[j, i] = 1
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ scikit-learn
2
+ numpy
3
+ accelerate