roberthsu2003 commited on
Commit
feacdc9
·
verified ·
1 Parent(s): 2396239

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +37 -3
README.md CHANGED
@@ -25,9 +25,43 @@ It achieves the following results on the evaluation set:
25
  - Loss: 1.3109
26
  - Accuracy: 0.5962
27
 
28
- ## Model description
29
-
30
- More information needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  ## Intended uses & limitations
33
 
 
25
  - Loss: 1.3109
26
  - Accuracy: 0.5962
27
 
28
+ ## 模型的使用
29
+
30
+ from typing import Any
31
+ import torch
32
+
33
+ class MultipleChoicePipeline:
34
+ def __init__(self, model, tokenizer) -> None:
35
+ self.model = model
36
+ self.tokenizer = tokenizer
37
+ self.device = model.device
38
+
39
+ def preprocess(self, context, question, choices):
40
+ cs, qcs = [], []
41
+ for choice in choices:
42
+ cs.append(context)
43
+ qcs.append(question + " " + choice)
44
+ return tokenizer(cs, qcs, truncation="only_first", max_length=256, return_tensors="pt")
45
+
46
+ def predict(self, inputs):
47
+ inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}
48
+ return self.model(**inputs).logits
49
+
50
+ def postprocess(self, logits, choices):
51
+ predition = torch.argmax(logits, dim=-1).cpu().item()
52
+ return choices[predition]
53
+
54
+ def __call__(self, context, question, choices) -> Any:
55
+ inputs = self.preprocess(context,question,choices)
56
+ logits = self.predict(inputs)
57
+ result = self.postprocess(logits, choices)
58
+ return result
59
+
60
+ if __name__ == "__main__":
61
+ pipe = MultipleChoicePipeline(model, tokenizer)
62
+ result1 = pipe("國堂在台北上班","國堂在哪裏上班?",['台北','台中'])
63
+ print(result1)
64
+
65
 
66
  ## Intended uses & limitations
67