Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from streamlit_text_rating.st_text_rater import st_text_rater | |
| from sentiment import classify_sentiment | |
| from sentiment_onnx_classify import classify_sentiment_onnx, classify_sentiment_onnx_quant | |
| from zeroshot_clf import zero_shot_classification | |
| import time | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
# Plotly display options shared by the charts on this page: hide the mode bar.
# (A module-level ``global`` statement has no effect, so none is used here.)
_plotly_config = {'displayModeBar': False}

# Must run before any other Streamlit call that renders output.
st.set_page_config(
    layout="wide",  # "centered" or "wide"
    initial_sidebar_state="auto",  # "auto", "expanded" or "collapsed"
    # Use the real title; the string 'None' would be shown literally in the tab.
    page_title='NLP use cases',
)

# Remove the top padding of the main content container.
# NOTE(review): the selector looks like it targets older Streamlit markup -
# confirm it still matches on the deployed Streamlit version.
padding_top = 0
st.markdown(
    f"""
<style>
.reportview-container .main .block-container{{
padding-top: {padding_top}rem;
}}
</style>""",
    unsafe_allow_html=True,
)
def set_page_title(title):
    """Force the browser-tab title of the page to ``title``.

    Renders a zero-height iframe in the sidebar whose inline script sets the
    text of the parent document's ``<title>`` element and installs a
    ``MutationObserver`` that writes ``title`` back whenever the element's
    children change to something else. An observer stored on
    ``window.parent.titleObserver`` by an earlier call is disconnected first.

    Args:
        title: Text for the tab title. It is interpolated verbatim into
            single-quoted JavaScript inside a double-quoted HTML attribute,
            so it must not contain quote characters.
    """
    # Each statement is explicitly terminated with ';' and the script is
    # emitted on a single line, so joining lines can never fuse two
    # statements into invalid JavaScript.
    statements = [
        "const title = window.parent.document.querySelector('title');",
        "const oldObserver = window.parent.titleObserver;",
        "if (oldObserver) { oldObserver.disconnect(); }",
        "const newObserver = new MutationObserver(function(mutations) {"
        " const target = mutations[0].target;"
        f" if (target.text !== '{title}') {{ target.text = '{title}'; }}"
        " });",
        "newObserver.observe(title, { childList: true });",
        "window.parent.titleObserver = newObserver;",
        f"title.text = '{title}';",
    ]
    script = " ".join(statements)
    st.sidebar.markdown(
        unsafe_allow_html=True,
        body=f'<iframe height=0 srcdoc="<script>{script}</script>" />',
    )
# Pin the browser-tab title through the injected observer script.
set_page_title('NLP use cases')

# Hide Menu Option
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# Main page heading.
st.title("NLP use cases")

# Sidebar: heading plus the drop-down that selects which task the page shows.
with st.sidebar:
    st.title("NLP tasks")
    select_task = st.selectbox(
        label="Select task from drop down menu",
        options=[
            'README',
            'Detect Sentiment',
            'Zero Shot Classification',
        ],
    )

# Landing view: only a header is rendered for the README entry.
if select_task == 'README':
    st.header("NLP Summary")
def _timed_call(func, *args):
    """Call ``func(*args)`` and return ``(result, elapsed_milliseconds)``."""
    # perf_counter is the clock intended for measuring short intervals.
    started = time.perf_counter()
    result = func(*args)
    return result, (time.perf_counter() - started) * 1000


def _simulate_runs(func, texts, runs, lower_ms, upper_ms):
    """Time ``runs`` calls of ``func(texts)``.

    Returns ``(last_result, timings)`` where ``timings`` holds the per-call
    durations in milliseconds, clipped to ``[lower_ms, upper_ms]``.
    """
    result = None
    timings = []
    for _ in range(runs):
        result, elapsed_ms = _timed_call(func, texts)
        timings.append(elapsed_ms)
    return result, np.clip(timings, lower_ms, upper_ms)


if select_task == 'Detect Sentiment':
    st.header("You are now performing Sentiment Analysis")
    input_texts = st.text_input(label="Input texts separated by comma")

    # One button per runtime, plus one that benchmarks all three.
    c1, c2, c3, c4 = st.columns(4)
    with c1:
        response1 = st.button("Normal runtime")
    with c2:
        response2 = st.button("ONNX runtime")
    with c3:
        response3 = st.button("ONNX runtime with Quantization")
    with c4:
        response4 = st.button("Simulate 100 runs each runtime")

    button_pressed = any([response1, response2, response3, response4])
    if button_pressed and not input_texts.strip():
        # Nothing to classify; avoid running the models on a blank string.
        st.warning("Please enter at least one text to classify.")
    elif button_pressed:
        # Map the pressed single-run button to its classifier.
        if response1:
            classifier = classify_sentiment
        elif response2:
            classifier = classify_sentiment_onnx
        elif response3:
            classifier = classify_sentiment_onnx_quant
        else:
            classifier = None  # only the simulation button is left

        if classifier is not None:
            sentiments, elapsed_ms = _timed_call(classifier, input_texts)
            st.write(f"Time taken for computation {elapsed_ms:.1f} ms")
        else:
            # Column label, classifier and clip range (ms) for each runtime.
            # Timings are clipped before plotting, so values outside the
            # range pile up in the edge bins of the histogram.
            simulations = (
                ('Normal Runtime (ms)', classify_sentiment, 10, 40),
                ('ONNX Runtime (ms)', classify_sentiment_onnx, 0, 20),
                ('ONNX Quant Runtime (ms)', classify_sentiment_onnx_quant, 0, 10),
            )
            timings = {}
            for column, simulated, low, high in simulations:
                # ``sentiments`` ends up holding the output of the last
                # (quantized) run, which is what gets displayed below.
                sentiments, timings[column] = _simulate_runs(
                    simulated, input_texts, 100, low, high
                )
            temp_df = pd.DataFrame(timings)

            from plotly.subplots import make_subplots

            fig = make_subplots(
                rows=1,
                cols=3,
                start_cell="bottom-left",
                subplot_titles=[
                    'Normal Runtime',
                    'ONNX Runtime',
                    'ONNX Runtime with Quantization',
                ],
            )
            for col, column in enumerate(temp_df.columns, start=1):
                fig.add_trace(go.Histogram(x=temp_df[column]), row=1, col=col)
            fig.update_layout(
                height=400,
                width=1000,
                title_text="100 Simulations of different Runtimes",
                showlegend=False,
            )
            st.plotly_chart(fig, config=_plotly_config)

        # One rating widget per input text, coloured by predicted label.
        # Presumes the classifier returns one label per comma-separated
        # text, in order - TODO confirm against the classifier modules.
        for i, (t, sentiment) in enumerate(zip(input_texts.split(','), sentiments)):
            if sentiment == 'Positive':
                background = 'rgb(154,205,50)'
            else:
                background = 'rgb(233, 116, 81)'
            # The index keeps widget keys unique when a text is repeated.
            response = st_text_rater(
                t + f"--> This statement is {sentiment}",
                color_background=background,
                key=f"{i}_{t}",
            )
| if select_task=='Zero Shot Classification': | |
| st.header("You are now performing Zero Shot Classification") | |
| input_texts = st.text_input(label="Input text to classify into topics") | |
| input_lables = st.text_input(label="Enter labels separated by commas") | |
| c1,c2,c3,c4=st.columns(4) | |
| with c1: | |
| response1=st.button("Normal runtime") | |
| with c2: | |
| response2=st.button("ONNX runtime") | |
| with c3: | |
| response3=st.button("ONNX runtime with Quantization") | |
| with c4: | |
| response4 = st.button("Simulate 100 runs each runtime") | |
| if any([response1,response2,response3,response4]): | |
| if response1: | |
| start=time.time() | |
| output = zero_shot_classification(input_texts, input_lables) | |
| end=time.time() | |
| st.write("") | |
| st.write(f"Time taken for computation {(end-start)*1000:.1f} ms") | |
| st.plotly_chart(output, config=_plotly_config) | |