JesseStover committed on
Commit
7ffe52f
·
verified ·
1 Parent(s): 2a75178

Upload code/inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/inference.py +34 -0
code/inference.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sagemaker_inference import encoder
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForMultipleChoice
4
+
5
+
6
def model_fn(model_dir):
    """Load the tokenizer and multiple-choice model stored in *model_dir*.

    Args:
        model_dir: Location passed to both ``from_pretrained`` calls.

    Returns:
        A dict with the loaded model under ``"model"`` and the loaded
        tokenizer under ``"tokenizer"``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    loaded_model = AutoModelForMultipleChoice.from_pretrained(model_dir)

    artifacts = {}
    artifacts["model"] = loaded_model
    artifacts["tokenizer"] = loaded_tokenizer
    return artifacts
10
+
11
+
12
def predict_fn(data, model):
    """Score every candidate answer for a prompt with the multiple-choice model.

    Args:
        data: Mapping with a ``"prompt"`` entry and a ``"candidates"`` entry;
            each candidate is paired with the prompt before tokenization.
        model: Mapping holding the network under ``"model"`` and the
            tokenizer under ``"tokenizer"`` (the dict ``model_fn`` returns).

    Returns:
        The ``logits`` attribute of the model output.
        NOTE(review): presumably shaped (1, num_candidates) -- confirm.

    Raises:
        KeyError: If a required key is missing from ``data`` or ``model``.
        ValueError: If ``"candidates"`` is empty.
    """
    prompt = data["prompt"]
    candidates = data["candidates"]
    if not candidates:
        raise ValueError("'candidates' must contain at least one entry")

    # ``model`` is the artifact dict, not the network itself; calling the
    # dict directly would raise TypeError, so unpack both pieces first.
    tokenizer = model["tokenizer"]
    network = model["model"]

    encoded = tokenizer(
        [[prompt, candidate] for candidate in candidates],
        return_tensors="pt",
        padding=True,
    )

    # Add a leading dimension so the prompt/candidate pairs form one example.
    batched = {name: tensor.unsqueeze(0) for name, tensor in encoded.items()}

    # No labels are passed: they would only make the model compute a loss
    # that inference never uses; the logits do not depend on them.
    with torch.no_grad():
        outputs = network(**batched)

    return outputs.logits
30
+
31
+
32
def output_fn(prediction, content_type):
    """Encode the prediction for the response.

    Iterating ``prediction`` yields its items in order; each item is keyed
    by its position and the resulting dict is handed to ``encoder.encode``.

    Args:
        prediction: Iterable result produced by ``predict_fn``.
        content_type: Content type forwarded unchanged to ``encoder.encode``.

    Returns:
        Whatever ``encoder.encode`` returns for the position-keyed dict.
    """
    indexed = dict(enumerate(prediction))
    return encoder.encode(indexed, content_type)