Spaces:
Runtime error
Runtime error
Commit
·
84fa2e9
1
Parent(s):
8bb7965
change threading options for onnx inference
Browse files
app.py
CHANGED
|
@@ -87,6 +87,10 @@ hide_streamlit_style = """
|
|
| 87 |
"""
|
| 88 |
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
|
| 92 |
def create_model_dir(chkpt, model_dir):
|
|
@@ -180,6 +184,9 @@ if select_task=='README':
|
|
| 180 |
if select_task == 'Detect Sentiment':
|
| 181 |
t1=time.time()
|
| 182 |
tokenizer_sentiment,sentiment_session = sentiment_task_selected(task=select_task)
|
|
|
|
|
|
|
|
|
|
| 183 |
t2 = time.time()
|
| 184 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
| 185 |
|
|
@@ -210,7 +217,9 @@ if select_task == 'Detect Sentiment':
|
|
| 210 |
|
| 211 |
if select_task=='Zero Shot Classification':
|
| 212 |
t1=time.time()
|
| 213 |
-
tokenizer_zs,
|
|
|
|
|
|
|
| 214 |
t2 = time.time()
|
| 215 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
| 216 |
|
|
@@ -225,7 +234,7 @@ if select_task=='Zero Shot Classification':
|
|
| 225 |
|
| 226 |
if response1:
|
| 227 |
start = time.time()
|
| 228 |
-
df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=
|
| 229 |
_tokenizer=tokenizer_zs)
|
| 230 |
end = time.time()
|
| 231 |
st.write("")
|
|
|
|
| 87 |
"""
|
| 88 |
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
| 89 |
|
| 90 |
+
options = ort.SessionOptions()
|
| 91 |
+
options.intra_op_num_threads=1
|
| 92 |
+
options.inter_op_num_threads=1
|
| 93 |
+
|
| 94 |
|
| 95 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
|
| 96 |
def create_model_dir(chkpt, model_dir):
|
|
|
|
| 184 |
if select_task == 'Detect Sentiment':
|
| 185 |
t1=time.time()
|
| 186 |
tokenizer_sentiment,sentiment_session = sentiment_task_selected(task=select_task)
|
| 187 |
+
##below 2 steps are slower as caching is not enabled
|
| 188 |
+
# tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_mdl_dir)
|
| 189 |
+
# sentiment_session = ort.InferenceSession(f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}")
|
| 190 |
t2 = time.time()
|
| 191 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
| 192 |
|
|
|
|
| 217 |
|
| 218 |
if select_task=='Zero Shot Classification':
|
| 219 |
t1=time.time()
|
| 220 |
+
tokenizer_zs,session_zs = zs_task_selected(task=select_task)
|
| 221 |
+
# tokenizer_zs= AutoTokenizer.from_pretrained(zs_mdl_dir)
|
| 222 |
+
# session_zs = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}")
|
| 223 |
t2 = time.time()
|
| 224 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
| 225 |
|
|
|
|
| 234 |
|
| 235 |
if response1:
|
| 236 |
start = time.time()
|
| 237 |
+
df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=session_zs,
|
| 238 |
_tokenizer=tokenizer_zs)
|
| 239 |
end = time.time()
|
| 240 |
st.write("")
|