rhulsker committed on
Commit
3b434b5
·
1 Parent(s): c8307e3

adding inference example

Browse files
Files changed (1) hide show
  1. inference.py +75 -0
inference.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

# Load the fine-tuned ticket-categorisation model and its matching tokenizer.
model = AutoModelForSequenceClassification.from_pretrained("DavinciTech/BERT_Categorizer")
tokenizer = AutoTokenizer.from_pretrained("DavinciTech/BERT_Categorizer")

# Fall back to CPU when CUDA is unavailable so the example runs anywhere
# (hard-coding "cuda" crashes on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Example helpdesk tickets: a "Title" line and a "Description" joined into one
# string per ticket — presumably the format the model was trained on (TODO confirm).
input_texts = [

    ("Title: Scanner not working\n"
     "Description: Good morning Team My scanner is not connecting to the image saving folder I assume it has "
     "something to do with the merge from last week I need to scan to be able to do payroll which needs to all be done "
     "before and scanning is the first step"),

    ("Title: New mouse please\n"
     "Description: My mouse is acting up a little bit could I get a new one please?"),

    ("Title: Internet outage\n"
     "Description: The whole internet is down for everyone in the office"),

]

# Map class index -> label name, in the iteration order of the model's label dictionary.
id2label = {k: l for k, l in enumerate(model.config.LABEL_DICTIONARY.keys())}

# Encode the text
encoded = tokenizer(input_texts, truncation=True, padding="max_length", max_length=512, return_tensors="pt").to(device)

# Run the model. Inference only, so disable gradient tracking to save memory.
# Each text yields 13 logits: 4 Impact + 4 Urgency + 5 Type columns.
with torch.no_grad():
    logits = model(**encoded).logits.cpu().numpy()
# Column layout of the model's 13 logits: three independent label groups.
IMPACT_LABELS = ["I1", "I2", "I3", "I4"]
IMPACT_INDICES = range(0, 4)
URGENCY_LABELS = ["U1", "U2", "U3", "U4"]
URGENCY_INDICES = range(4, 8)
TYPE_LABELS = ["T1", "T2", "T3", "T4", "T5"]
TYPE_INDICES = range(8, 13)
ALL_LABELS = IMPACT_LABELS + URGENCY_LABELS + TYPE_LABELS

def get_preds_from_logits(logits):
    """Convert raw logits into one-hot predictions, one winner per label group.

    Each group — Impact (columns 0-3), Urgency (columns 4-7) and Type
    (columns 8-12) — is treated as an independent multiclass problem: the
    column with the highest logit inside the group is set to 1 and every
    other column in that group stays 0.

    Args:
        logits: array of shape (batch, 13), one column per label in
            ``ALL_LABELS`` order.

    Returns:
        A 0/1 array of the same shape with exactly one 1 per group per row.
    """
    preds = np.zeros(logits.shape)
    rows = np.arange(len(preds))
    # Argmax within each group, offset by the group's first column to index
    # into the full 13-wide row. (No re-zeroing needed: preds starts at 0.)
    for group in (IMPACT_INDICES, URGENCY_INDICES, TYPE_INDICES):
        preds[rows, group[0] + np.argmax(logits[:, group], axis=-1)] = 1
    return preds
# Turn the logits into one-hot predictions, then map each active column back
# to its label name via id2label.
preds = get_preds_from_logits(logits)
decoded_preds = []
for row in preds:
    decoded_preds.append([id2label[idx] for idx, flag in enumerate(row) if flag == 1])

print("\n")

# Report each ticket with its predicted Impact / Urgency / Type, resolving the
# short label codes (I*, U*, T*) to human-readable names from the model config.
for text, labels in zip(input_texts, decoded_preds):
    print(text)
    for header, prefix in (("Impact:", "I"), ("Urgency:", "U"), ("Type:", "T")):
        print(header, [model.config.LABEL_DICTIONARY[lab] for lab in labels if lab.startswith(prefix)])
    print("")