Atquiya-Labiba commited on
Commit
b97935a
·
1 Parent(s): b11bc8f

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -5,6 +5,7 @@ import torch, json
5
 
6
  tokenizer = AutoTokenizer.from_pretrained("roberta-base")
7
 
 
8
  with open("tag_types_encoded.json", "r") as fp:
9
  encode_tag_types = json.load(fp)
10
 
@@ -14,12 +15,20 @@ inf_session = rt.InferenceSession('question-classifier-quantized.onnx')
14
  input_name = inf_session.get_inputs()[0].name
15
  output_name = inf_session.get_outputs()[0].name
16
 
 
 
17
  def classify_question_tags(description):
18
- input_ids = tokenizer(description)['input_ids'][:512]
19
- logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
20
- logits = torch.FloatTensor(logits)
21
- probs = torch.sigmoid(logits)[0]
22
- return dict(zip(tags, map(float, probs)))
 
 
 
 
 
 
23
 
24
  label = gr.Label(num_top_classes=5)
25
  iface = gr.Interface(fn=classify_question_tags, inputs="text", outputs=label)
 
5
 
6
  tokenizer = AutoTokenizer.from_pretrained("roberta-base")
7
 
8
+
9
  with open("tag_types_encoded.json", "r") as fp:
10
  encode_tag_types = json.load(fp)
11
 
 
15
  input_name = inf_session.get_inputs()[0].name
16
  output_name = inf_session.get_outputs()[0].name
17
 
18
+ threshold = 0.5
19
+
20
  def classify_question_tags(description):
21
+ input_ids = tokenizer(description)['input_ids'][:512]
22
+ logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
23
+ probs = torch.sigmoid(torch.FloatTensor(logits))[0]
24
+
25
+ filtered = {tag: float(prob) for tag, prob in zip(tags, probs) if prob > threshold}
26
+
27
+ if not filtered:
28
+ topk = torch.topk(probs, k=5)
29
+ filtered = {tags[i]: float(probs[i]) for i in topk.indices}
30
+
31
+ return filtered
32
 
33
  label = gr.Label(num_top_classes=5)
34
  iface = gr.Interface(fn=classify_question_tags, inputs="text", outputs=label)