Spaces:
Sleeping
Sleeping
Chidam Gopal
committed on
directly use the onnx quantized file
Browse files
- infer_intent.py +19 -5
- requirements.txt +3 -1
infer_intent.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
| 1 |
-
from transformers import
|
| 2 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class IntentClassifier:
|
|
@@ -15,10 +19,20 @@ class IntentClassifier:
|
|
| 15 |
self.label2id = {label:id for id,label in self.id2label.items()}
|
| 16 |
|
| 17 |
self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def find_intent(self, sequence, verbose=False):
|
| 24 |
inputs = self.tokenizer(sequence,
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer
|
| 2 |
import torch
|
| 3 |
+
import onnxruntime as ort
|
| 4 |
+
import numpy as np
|
| 5 |
+
import requests
|
| 6 |
+
import os
|
| 7 |
|
| 8 |
|
| 9 |
class IntentClassifier:
|
|
|
|
| 19 |
self.label2id = {label:id for id,label in self.id2label.items()}
|
| 20 |
|
| 21 |
self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
|
| 22 |
+
model_url = "https://huggingface.co/Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier/resolve/main/onnx/model_quantized.onnx"
|
| 23 |
+
model_dir_path = "models"
|
| 24 |
+
model_path = f"{model_dir_path}/mobilebert-uncased-finetuned-LoRA-intent-classifier_model_quantized.onnx"
|
| 25 |
+
if not os.path.exists(model_dir_path):
|
| 26 |
+
os.makedirs(model_dir_path)
|
| 27 |
+
if not os.path.exists(model_path):
|
| 28 |
+
print("Downloading ONNX model...")
|
| 29 |
+
response = requests.get(model_url)
|
| 30 |
+
with open(model_path, "wb") as f:
|
| 31 |
+
f.write(response.content)
|
| 32 |
+
print("ONNX model downloaded.")
|
| 33 |
+
|
| 34 |
+
# Load the ONNX model
|
| 35 |
+
self.ort_session = ort.InferenceSession(model_path)
|
| 36 |
|
| 37 |
def find_intent(self, sequence, verbose=False):
|
| 38 |
inputs = self.tokenizer(sequence,
|
requirements.txt
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
transformers==4.45.1
|
| 2 |
torch==2.4.1
|
| 3 |
streamlit==1.38.0
|
| 4 |
-
matplotlib==3.9.2
|
|
|
|
|
|
|
|
|
| 1 |
transformers==4.45.1
|
| 2 |
torch==2.4.1
|
| 3 |
streamlit==1.38.0
|
| 4 |
+
matplotlib==3.9.2
|
| 5 |
+
## onnx
|
| 6 |
+
onnxruntime==1.19.2
|