Rimi98 commited on
Commit
b0ab615
·
1 Parent(s): f688dec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -5,14 +5,14 @@ import torch, json
5
 
6
  token = AutoTokenizer.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')
7
 
8
- types = [0,1]
9
 
10
  inf_session = onnxruntime.InferenceSession('classifier_quantized.onnx')
11
  input_name = inf_session.get_inputs()[0].name
12
  output_name = inf_session.get_outputs()[0].name
13
 
14
  def classify(review):
15
- input_ids = token(review)['inputs_ids'][:512]
16
  logits = inf_session.run([output_name],{input_name: [input_ids]})[0]
17
  logits = torch.FloatTensorlogits(logits)
18
  probs = torch.sigmoid(logits)[0]
 
5
 
6
  token = AutoTokenizer.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')
7
 
8
+ types = [{0:'positive'},{1:'Negative'}]
9
 
10
  inf_session = onnxruntime.InferenceSession('classifier_quantized.onnx')
11
  input_name = inf_session.get_inputs()[0].name
12
  output_name = inf_session.get_outputs()[0].name
13
 
14
  def classify(review):
15
+ input_ids = token(review)['input_ids'][:512]
16
  logits = inf_session.run([output_name],{input_name: [input_ids]})[0]
17
  logits = torch.FloatTensorlogits(logits)
18
  probs = torch.sigmoid(logits)[0]