parth parekh
committed on
Commit
·
838063e
1
Parent(s):
e43c18e
testing out torch jit
Browse files- predictor.py +18 -10
predictor.py
CHANGED
|
@@ -82,21 +82,29 @@ test_sentences = [
|
|
| 82 |
"Lets do '42069' tonight it will be really fun what do you say ?"
|
| 83 |
]
|
| 84 |
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def predict(text):
|
| 88 |
-
with torch.
|
| 89 |
-
inputs = torch.tensor([text_pipeline(text)])
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
inputs = torch.cat([inputs,
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
| 96 |
return torch.argmax(outputs, dim=1).item()
|
| 97 |
|
| 98 |
|
| 99 |
-
|
| 100 |
# Test the sentences
|
| 101 |
for i, sentence in enumerate(test_sentences, 1):
|
| 102 |
prediction = predict(sentence)
|
|
|
|
| 82 |
"Lets do '42069' tonight it will be really fun what do you say ?"
|
| 83 |
]
|
| 84 |
|
| 85 |
+
# JIT-script the model once at module load so every predict() call reuses
# the compiled graph instead of re-interpreting Python code.
scripted_model = torch.jit.script(model)

# Preallocate the padding tensor once to avoid a fresh allocation per call.
# NOTE(review): zero padding assumes the vocabulary's padding index is 0 —
# confirm against the embedding layer's padding_idx.
MAX_LEN = max(FILTER_SIZES)
# Allocate directly on the target device: zeros(...).to(device) would build
# a CPU tensor first and then copy it over.
padding_tensor = torch.zeros(1, MAX_LEN, dtype=torch.long, device=device)
|
| 91 |
+
|
| 92 |
+
# Prediction function: runs the JIT-scripted model under inference mode.
def predict(text):
    """Classify *text* and return the predicted class index as a Python int."""
    with torch.inference_mode():  # stricter/cheaper than no_grad for pure inference
        token_ids = torch.tensor([text_pipeline(text)]).to(device)

        # Right-pad the sequence up to MAX_LEN when it is too short
        # (presumably so the largest conv filter fits — verify against model).
        deficit = MAX_LEN - token_ids.size(1)
        if deficit > 0:
            token_ids = torch.cat([token_ids, padding_tensor[:, :deficit]], dim=1)

        # Forward pass through the scripted model.
        outputs = scripted_model(token_ids)

        # Index of the highest-scoring class.
        return torch.argmax(outputs, dim=1).item()
|
| 106 |
|
| 107 |
|
|
|
|
| 108 |
# Test the sentences
|
| 109 |
for i, sentence in enumerate(test_sentences, 1):
|
| 110 |
prediction = predict(sentence)
|