Hugging Face Spaces commit page (Space status: Runtime error; commit message: "first try").
Browse files- w2v_ovr_svc.sav → models/w2v_ovr_svc.sav +0 -0
- requirements.txt +5 -0
- text_class_app.py +33 -0
- utils.py +87 -0
w2v_ovr_svc.sav → models/w2v_ovr_svc.sav
RENAMED
|
File without changes
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pip requirements for the Streamlit text-classification Space.
# NOTE: `re` and `pickle` were removed — they are Python standard-library
# modules, not pip packages, and listing them makes `pip install` fail.
streamlit==1.4.0
gensim==4.1.2
transformers==4.16.1
# transformers is used with framework='pt' in utils.py, so torch is required.
torch
# Used directly by utils.py (also pulled in transitively, but be explicit).
numpy
pandas
text_class_app.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st

import utils

########## Title for the Web App ##########
st.title("Text Classification for Service Feedback")

########## Create Input field ##########
feedback = st.text_input('Type your text here', 'The staff were extremely polite and helpful!')

if st.button('Click for predictions!'):
    with st.spinner('Generating predictions...'):
        # BUG FIX: the prediction helpers live in the utils module; the bare
        # names raised NameError at runtime (the Space's "Runtime error").
        result = utils.get_single_prediction(feedback)

    st.success(f'Your text has been predicted to fall under the following labels: {result[:-1]}. This text is {result[-1]}.')

st.text('Or... Upload a csv file if you have many texts')

uploaded_file = st.file_uploader("Please upload a csv file with only 1 column of texts.")

if uploaded_file is not None:

    with st.spinner('Generating predictions...'):
        # BUG FIX: same missing utils. prefix as above.
        results = utils.get_multiple_predictions(uploaded_file)

    # BUG FIX: st.download_button expects str/bytes for `data`, not a
    # DataFrame — serialise the results to CSV text first.
    st.download_button(
        label="Download results as CSV",
        data=results.to_csv(index=False),
        file_name='results.csv',
        mime='text/csv',
    )
utils.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pickle
import re

import numpy as np
import pandas as pd
from gensim.models.keyedvectors import KeyedVectors
from transformers import pipeline
# BUG FIX: numpy and pandas were used below (np.mean, pd.read_csv, ...) but
# never imported, which raised NameError on first prediction.

# Word2Vec embeddings used to vectorise feedback text.
w2v = KeyedVectors.load('models/word2vec')
# Set gives O(1) membership tests when dropping out-of-vocabulary words.
# (sorting before building a set has no effect, so it was removed)
w2v_vocab = set(w2v.index_to_key)
# One-vs-rest SVC trained on averaged word vectors.
# NOTE(review): unpickling is only safe for trusted model files.
# BUG FIX: use a context manager so the file handle is closed.
with open('models/w2v_ovr_svc.sav', 'rb') as f:
    model = pickle.load(f)
# Zero-shot classifier used for sentiment. device=0 assumes a GPU is
# available — use device=-1 on CPU-only deployments.
classifier = pipeline("zero-shot-classification",
                      model="facebook/bart-large-mnli", device=0, framework='pt'
                      )

# Multi-label targets, in the column order the OvR model was trained with.
labels = [
    'communication', 'waiting time',
    'information', 'user interface',
    'facilities', 'location', 'price'
]
| 19 |
+
def get_sentiment_label_facebook(list_of_sent_dicts):
    """Collapse a zero-shot classifier result into 'negative' or 'positive'.

    The classifier returns its candidate labels sorted by score, so the
    first entry of the 'labels' list is the top prediction.
    """
    top_label = list_of_sent_dicts['labels'][0]
    return 'negative' if top_label == 'negative' else 'positive'
| 24 |
+
|
| 25 |
+
def get_single_prediction(text):
    """Predict topic labels and sentiment for one piece of feedback.

    Returns a list whose last element is the sentiment
    ('positive'/'negative') and whose preceding elements are the predicted
    topic labels.
    """
    # Normalise: lower case, then strip punctuation/special characters.
    text = text.lower()
    text = re.sub(r'[^0-9a-zA-Z\s]', '', text)  # raw string for the regex

    # Remove OOV words so every remaining token has a word vector.
    tokens = [i for i in text.split() if i in w2v_vocab]
    text = ' '.join(tokens)

    # BUG FIX: if every word was OOV, np.mean([]) returns NaN and the model
    # call fails — return a sentinel the app can display instead.
    if not tokens:
        return ['no recognisable words']

    # Sentence vector = average of word vectors.
    text_vectors = np.mean([w2v[i] for i in tokens], axis=0)

    # BUG FIX: sklearn's predict() expects a 2-D array of samples; the
    # original passed a 1-D vector. Reshape to one row and take that row's
    # indicator vector so enumerate() below walks labels, not samples.
    results = model.predict(text_vectors.reshape(1, -1))[0]

    # Get sentiment via zero-shot classification.
    sentiment = get_sentiment_label_facebook(classifier(text,
                                             candidate_labels=['positive', 'negative'],
                                             hypothesis_template='The sentiment of this is {}'))

    # Consolidate: keep labels whose indicator is 1, then append sentiment.
    pred_labels = [labels[idx] for idx, tag in enumerate(results) if tag == 1]
    pred_labels.append(sentiment)

    return pred_labels
| 50 |
+
|
| 51 |
+
def get_multiple_predictions(csv):
    """Predict topic labels and sentiment for every row of an uploaded CSV.

    The file must contain a single column of texts. Returns a DataFrame
    with the cleaned text, one indicator column per topic label, and a
    'sentiment' column; rows that were blank after cleaning are appended
    back without predictions so the user gets their full file back.
    """
    df = pd.read_csv(csv)
    df.columns = ['sequence']

    # BUG FIX: NaN cells crashed the .apply() chain below — treat missing
    # texts as blank so they flow into the `invalid` bucket instead.
    df['sequence'] = df['sequence'].fillna('')

    # Normalise: lower case, strip punctuation/special characters
    # (raw string + explicit regex=True for the pattern replace).
    df['sequence'] = df['sequence'].str.lower()
    df['sequence'] = df['sequence'].str.replace(r'[^0-9a-zA-Z\s]', '', regex=True)

    # Remove OOV words so every remaining token has a word vector.
    df['sequence'] = df['sequence'].apply(lambda x: ' '.join([i for i in x.split() if i in w2v_vocab]))

    # Set aside rows that are blank after cleaning — they can't be vectorised.
    invalid = df[df['sequence'] == '']
    df = df[df['sequence'] != ''].reset_index(drop=True)

    # Sentence vector = average of word vectors, one row per text.
    series_text_vectors = pd.DataFrame(df['sequence'].apply(lambda x: np.mean([w2v[i] for i in x.split()], axis=0)).values.tolist())

    # Multi-label predictions, one indicator column per topic.
    pred_results = pd.DataFrame(model.predict(series_text_vectors), columns=labels)

    # BUG FIX: join the predictions, not the raw word vectors — the original
    # joined series_text_vectors, so the output never contained the labels.
    final_results = df.join(pred_results)

    # Sentiment label per row via zero-shot classification.
    final_results['sentiment'] = final_results['sequence'].apply(
        lambda x: get_sentiment_label_facebook(classifier(x,
                                               candidate_labels=['positive', 'negative'],
                                               hypothesis_template='The sentiment of this is {}')))

    # Append the rows we couldn't predict (label columns become NaN).
    if len(invalid) == 0:
        return final_results
    return pd.concat([final_results, invalid]).reset_index(drop=True)