Spaces:
Runtime error
Runtime error
mlkorra
committed on
Commit
·
7dfe9b5
1
Parent(s):
96a99b3
Update App
Browse files
app.py
CHANGED
|
@@ -23,23 +23,25 @@ import os
|
|
| 23 |
|
| 24 |
def visualizer(prob_req, embed, df, index, company_name):
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
|
| 44 |
def clean_text(text):
|
| 45 |
|
|
@@ -67,30 +69,31 @@ def preprocess(name, group, state, states_used, desc):
|
|
| 67 |
@st.cache(persist=True,suppress_st_warning=True)
|
| 68 |
def load_topic_model(model_path, name, group, state, states_used, desc):
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
| 94 |
|
| 95 |
def app():
|
| 96 |
|
|
@@ -115,7 +118,7 @@ def app():
|
|
| 115 |
#state= st.selectbox('Select state the company is based in', states)
|
| 116 |
#states_used = st.multiselect('Select states you want to analyse', states)
|
| 117 |
|
| 118 |
-
examples = [['Coursera','Education','California',['California','
|
| 119 |
|
| 120 |
if check_examples:
|
| 121 |
example = examples[0]
|
|
@@ -137,13 +140,13 @@ def app():
|
|
| 137 |
states_used = st.multiselect('Select states you want to analyse', states)
|
| 138 |
|
| 139 |
if(st.button("Analyse Competition")):
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
|
| 148 |
|
| 149 |
if __name__ == "__main__":
|
|
|
|
| 23 |
|
| 24 |
def visualizer(prob_req, embed, df, index, company_name):
    """Plot the dataset embeddings and overlay the user's company on them.

    Args:
        prob_req: Topic probabilities for the user's request. They are
            projected with the pickled reducer, and columns 0 and 1 of the
            result are plotted.
        embed: Already-reduced embeddings of the dataset rows to display;
            columns 0 and 1 are used as the x and y axes.
        df: Company dataframe. The 'headquarters', 'company_name' and
            'industry_groups' columns are read.
        index: Positional row indices into ``df`` (used with ``df.iloc``)
            for the rows being plotted.
        company_name: Name shown for the user's marker in the legend and
            on hover.

    Returns:
        None. The figure is rendered with ``st.plotly_chart``.
    """
    with st.spinner("Visualizing the results !!!"):
        fname = 'topicmodel/saving_example.sav'
        # Load the UMAP dimensionality reduction model trained on the rest
        # of the probabilities. The context manager closes the file handle,
        # which the previous bare open() call leaked.
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # this is only acceptable while the .sav file is a trusted artifact
        # shipped with the app.
        with open(fname, 'rb') as reducer_file:
            reducer = pickle.load(reducer_file)
        embed_req = reducer.transform(prob_req)

        # Slice the selected rows once instead of three separate times.
        selected = df.iloc[index]

        # Add scatter plot for all embeddings from our dataset.
        fig1 = px.scatter(
            embed, x=0, y=1,
            color=selected['headquarters'],
            labels={'color': 'states'},
            hover_name=(selected['company_name']
                        + " with industry group: "
                        + selected['industry_groups']))

        # Add the data for the user's request and display.
        fig1.add_trace(
            go.Scatter(
                x=embed_req[:, 0],
                y=embed_req[:, 1],
                mode='markers',
                marker_symbol="hexagon2", marker_size=15,
                showlegend=True, name=company_name, hovertext=company_name))
        st.plotly_chart(fig1)
|
| 45 |
|
| 46 |
def clean_text(text):
|
| 47 |
|
|
|
|
| 69 |
@st.cache(persist=True,suppress_st_warning=True)
def load_topic_model(model_path, name, group, state, states_used, desc):
    """Load the topic model and data, then score the user's company.

    Args:
        model_path: Path passed to ``BERTopic.load``.
        name: Company name, forwarded to ``preprocess``.
        group: Industry group, forwarded to ``preprocess``.
        state: State the company is based in, forwarded to ``preprocess``.
        states_used: Iterable of state names; dataset rows whose
            'headquarters' value contains one of them are selected.
        desc: Company description, forwarded to ``preprocess``.

    Returns:
        A tuple ``(topic, prob_req, select, df, index)``: the two values
        returned by ``model.transform`` for the request, the rows of the
        saved embeddings for the selected companies, the full dataframe,
        and the list of selected row indices.
    """
    with st.spinner("Creating Topic Models ....."):
        # Load BERTopic.
        model = BERTopic.load(model_path)

        # Load dataset (used for creating the scatter plot).
        data_path = 'topicmodel/data.csv'
        df = pd.read_csv(data_path)

        # Load embeddings reduced by UMAP for the points to be displayed
        # by the scatter plot.
        embeddings_path = 'topicmodel/embed.npy'
        embeddings = np.load(embeddings_path)

        # Preprocess user inputs.
        request = preprocess(name, group, state, states_used, desc)

        # Only select states that the user wants to compare.
        index = []
        for state_used in states_used:
            # regex=False matches the state name literally; na=False makes
            # rows with a missing headquarters count as "no match" so the
            # mask stays boolean and can be used for indexing.
            matches = df['headquarters'].str.contains(
                state_used, regex=False, na=False)
            index.extend(df.index[matches].tolist())
        select = embeddings[index]

        # Use BERTopic to get probabilities.
        topic, prob_req = model.transform([request])
        st.text("Modelling done! plotting results now...")

        return topic, prob_req, select, df, index
|
| 97 |
|
| 98 |
def app():
|
| 99 |
|
|
|
|
| 118 |
#state= st.selectbox('Select state the company is based in', states)
|
| 119 |
#states_used = st.multiselect('Select states you want to analyse', states)
|
| 120 |
|
| 121 |
+
examples = [['Coursera','Education','California',['California','New York','Ohio'],'We are a social entrepreneurship company that partners with the top universities in the world to offer courses online for anyone to take, for free. We envision a future where the top universities are educating not only thousands of students, but millions. Our technology enables the best professors to teach tens or hundreds of thousands of students']]
|
| 122 |
|
| 123 |
if check_examples:
|
| 124 |
example = examples[0]
|
|
|
|
| 140 |
states_used = st.multiselect('Select states you want to analyse', states)
|
| 141 |
|
| 142 |
if(st.button("Analyse Competition")):
|
| 143 |
+
|
| 144 |
+
if companyname=="" or companydesc=="" or companygrp=="" or states_used==[]:
|
| 145 |
+
st.error("Some fields are empty!")
|
| 146 |
+
else:
|
| 147 |
+
model_path = 'topicmodel/my_model.pkl'
|
| 148 |
+
topic,prob_req,embed,df,index = load_topic_model(model_path, companyname, companygrp, state, states_used, companydesc)
|
| 149 |
+
visualizer(prob_req, embed, df, index, companyname)
|
| 150 |
|
| 151 |
|
| 152 |
if __name__ == "__main__":
|