mlkorra commited on
Commit
7dfe9b5
·
1 Parent(s): 96a99b3

Update App

Browse files
Files changed (1) hide show
  1. app.py +52 -49
app.py CHANGED
@@ -23,23 +23,25 @@ import os
23
 
24
  def visualizer(prob_req, embed, df, index, company_name):
25
 
26
- fname = 'topicmodel/saving_example.sav'
27
- reducer= pickle.load((open(fname, 'rb'))) #load the umap dimensionality reduction model trained on rest of probablities
28
- embed_req= reducer.transform(prob_req)
29
-
30
- #add scatter plot for all embeddings from our dataset
31
- fig1 = px.scatter(
32
- embed, x=0, y=1,
33
- color=df.iloc[index]['headquarters'], labels={'color': 'states'}, hover_name= df.iloc[index]['company_name'] + " with industry group: "+ df.iloc[index]['industry_groups'])
34
- #add the data for users request and display
35
- fig1.add_trace(
36
- go.Scatter(
37
- x=embed_req[:,0],
38
- y=embed_req[:,1],
39
- mode='markers',
40
- marker_symbol="hexagon2", marker_size=15,
41
- showlegend=True, name= company_name, hovertext= company_name))
42
- st.plotly_chart(fig1)
 
 
43
 
44
  def clean_text(text):
45
 
@@ -67,30 +69,31 @@ def preprocess(name, group, state, states_used, desc):
67
  @st.cache(persist=True,suppress_st_warning=True)
68
  def load_topic_model(model_path, name, group, state, states_used, desc):
69
 
70
-
71
- #load Bertopic
72
- model=BERTopic.load(model_path)
73
- #load dataset (used for creating scatter plot)
74
-
75
- data_path = 'topicmodel/data.csv'
76
- df = pd.read_csv(data_path)
77
- #load embeddings reduced by UMAP for the points to be displayed by scatter plot
78
-
79
- embeddings_path = 'topicmodel/embed.npy'
80
- embeddings = np.load(embeddings_path)
81
- #preprocess user inputs
82
- request= preprocess(name, group, state, states_used, desc)
83
- index=[]
84
- #only select states that user wants to compare
85
- for state_used in states_used:
86
- index.extend(df.index[df['headquarters'].str.contains(state_used)].tolist())
87
- select=embeddings[index]
88
-
89
- #use bert topic to get probabilities
90
- topic, prob_req= model.transform([request])
91
- st.text("Modelling done! plotting results now...")
92
-
93
- return topic, prob_req, select, df, index
 
94
 
95
  def app():
96
 
@@ -115,7 +118,7 @@ def app():
115
  #state= st.selectbox('Select state the company is based in', states)
116
  #states_used = st.multiselect('Select states you want to analyse', states)
117
 
118
- examples = [['Coursera','Education','California',['California','Washington','Ohio'],'We are a social entrepreneurship company that partners with the top universities in the world to offer courses online for anyone to take, for free. We envision a future where the top universities are educating not only thousands of students, but millions. Our technology enables the best professors to teach tens or hundreds of thousands of students']]
119
 
120
  if check_examples:
121
  example = examples[0]
@@ -137,13 +140,13 @@ def app():
137
  states_used = st.multiselect('Select states you want to analyse', states)
138
 
139
  if(st.button("Analyse Competition")):
140
- if companyname=="" or companydesc=="" or companygrp=="" or states_used==[]:
141
- st.error("Some fields are empty!")
142
- else:
143
-
144
- model_path = 'topicmodel/my_model.pkl'
145
- topic,prob_req,embed,df,index = load_topic_model(model_path, companyname, companygrp, state, states_used, companydesc)
146
- visualizer(prob_req, embed, df, index, companyname)
147
 
148
 
149
  if __name__ == "__main__":
 
23
 
24
  def visualizer(prob_req, embed, df, index, company_name):
25
 
26
+ with st.spinner("Visualizing the results !!!"):
27
+
28
+ fname = 'topicmodel/saving_example.sav'
29
+ reducer= pickle.load((open(fname, 'rb'))) #load the umap dimensionality reduction model trained on rest of probablities
30
+ embed_req= reducer.transform(prob_req)
31
+
32
+ #add scatter plot for all embeddings from our dataset
33
+ fig1 = px.scatter(
34
+ embed, x=0, y=1,
35
+ color=df.iloc[index]['headquarters'], labels={'color': 'states'}, hover_name= df.iloc[index]['company_name'] + " with industry group: "+ df.iloc[index]['industry_groups'])
36
+ #add the data for users request and display
37
+ fig1.add_trace(
38
+ go.Scatter(
39
+ x=embed_req[:,0],
40
+ y=embed_req[:,1],
41
+ mode='markers',
42
+ marker_symbol="hexagon2", marker_size=15,
43
+ showlegend=True, name= company_name, hovertext= company_name))
44
+ st.plotly_chart(fig1)
45
 
46
  def clean_text(text):
47
 
 
69
  @st.cache(persist=True,suppress_st_warning=True)
70
  def load_topic_model(model_path, name, group, state, states_used, desc):
71
 
72
+ with st.spinner("Creating Topic Models ....."):
73
+
74
+ #load Bertopic
75
+ model=BERTopic.load(model_path)
76
+ #load dataset (used for creating scatter plot)
77
+
78
+ data_path = 'topicmodel/data.csv'
79
+ df = pd.read_csv(data_path)
80
+ #load embeddings reduced by UMAP for the points to be displayed by scatter plot
81
+
82
+ embeddings_path = 'topicmodel/embed.npy'
83
+ embeddings = np.load(embeddings_path)
84
+ #preprocess user inputs
85
+ request= preprocess(name, group, state, states_used, desc)
86
+ index=[]
87
+ #only select states that user wants to compare
88
+ for state_used in states_used:
89
+ index.extend(df.index[df['headquarters'].str.contains(state_used)].tolist())
90
+ select=embeddings[index]
91
+
92
+ #use bert topic to get probabilities
93
+ topic, prob_req= model.transform([request])
94
+ st.text("Modelling done! plotting results now...")
95
+
96
+ return topic, prob_req, select, df, index
97
 
98
  def app():
99
 
 
118
  #state= st.selectbox('Select state the company is based in', states)
119
  #states_used = st.multiselect('Select states you want to analyse', states)
120
 
121
+ examples = [['Coursera','Education','California',['California','New York','Ohio'],'We are a social entrepreneurship company that partners with the top universities in the world to offer courses online for anyone to take, for free. We envision a future where the top universities are educating not only thousands of students, but millions. Our technology enables the best professors to teach tens or hundreds of thousands of students']]
122
 
123
  if check_examples:
124
  example = examples[0]
 
140
  states_used = st.multiselect('Select states you want to analyse', states)
141
 
142
  if(st.button("Analyse Competition")):
143
+
144
+ if companyname=="" or companydesc=="" or companygrp=="" or states_used==[]:
145
+ st.error("Some fields are empty!")
146
+ else:
147
+ model_path = 'topicmodel/my_model.pkl'
148
+ topic,prob_req,embed,df,index = load_topic_model(model_path, companyname, companygrp, state, states_used, companydesc)
149
+ visualizer(prob_req, embed, df, index, companyname)
150
 
151
 
152
  if __name__ == "__main__":