Atquiya-Labiba commited on
Commit
d368bde
·
1 Parent(s): b97935a

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -13
app.py CHANGED
@@ -15,20 +15,13 @@ inf_session = rt.InferenceSession('question-classifier-quantized.onnx')
15
  input_name = inf_session.get_inputs()[0].name
16
  output_name = inf_session.get_outputs()[0].name
17
 
18
- threshold = 0.5
19
-
20
  def classify_question_tags(description):
21
- input_ids = tokenizer(description)['input_ids'][:512]
22
- logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
23
- probs = torch.sigmoid(torch.FloatTensor(logits))[0]
24
-
25
- filtered = {tag: float(prob) for tag, prob in zip(tags, probs) if prob > threshold}
26
-
27
- if not filtered:
28
- topk = torch.topk(probs, k=5)
29
- filtered = {tags[i]: float(probs[i]) for i in topk.indices}
30
-
31
- return filtered
32
 
33
  label = gr.Label(num_top_classes=5)
34
  iface = gr.Interface(fn=classify_question_tags, inputs="text", outputs=label)
 
15
  input_name = inf_session.get_inputs()[0].name
16
  output_name = inf_session.get_outputs()[0].name
17
 
 
 
18
  def classify_question_tags(description):
19
+ input_ids = tokenizer(description)['input_ids'][:512]
20
+ logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
21
+ logits = torch.FloatTensor(logits)
22
+ probs = torch.sigmoid(logits)[0]
23
+
24
+ return dict(zip(tags, map(float, probs)))
 
 
 
 
 
25
 
26
  label = gr.Label(num_top_classes=5)
27
  iface = gr.Interface(fn=classify_question_tags, inputs="text", outputs=label)