File size: 6,596 Bytes
96a99b3
0ffa809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
957ad79
 
 
 
0ffa809
 
7dfe9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ffa809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
957ad79
0ffa809
 
7dfe9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f263f40
7dfe9b5
 
0ffa809
 
 
 
96a99b3
 
957ad79
fc5624a
 
0ffa809
96a99b3
 
 
 
 
 
 
 
 
 
 
 
 
7dfe9b5
96a99b3
 
 
fc5624a
 
 
 
 
96a99b3
 
 
 
 
 
fc5624a
 
 
 
 
0ffa809
fc5624a
7dfe9b5
 
fc5624a
7dfe9b5
 
 
 
ba27bfb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from json import load
import streamlit as st
import pandas as pd
import numpy as np
import re
import string

from nltk.stem import WordNetLemmatizer
import umap

import plotly.graph_objects as go
from plotly import tools
import plotly.offline as py
import plotly.express as px

from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
from bertopic import BERTopic
import pickle
import os

def read_markdown(path, parent='about/'):
    """Return the text of a markdown file stored under *parent*.

    Args:
        path: File name relative to *parent* (e.g. ``"userguide.md"``).
        parent: Directory holding the markdown assets (default ``'about/'``).

    Returns:
        The file's full contents as a string.
    """
    # Explicit encoding avoids platform-dependent defaults (e.g. cp1252 on Windows).
    with open(os.path.join(parent, path), encoding='utf-8') as f:
        return f.read()

def visualizer(prob_req, embed, df, index, company_name):
    """Render the topic-probability embedding scatter plot in Streamlit.

    Plots the pre-computed 2-D embeddings of the dataset companies (colored
    by headquarters state) and overlays the user's company, projected into
    the same space with a pre-trained UMAP reducer.

    Args:
        prob_req: Topic-probability vector(s) for the user's company.
        embed: 2-D embeddings for the dataset rows referenced by *index*.
        df: Companies dataframe (needs 'headquarters', 'company_name',
            'industry_groups' columns).
        index: Row positions of *df* to display.
        company_name: Label for the user's company marker.
    """
    with st.spinner("Visualizing the results !!!"):

        fname = 'topicmodel/saving_example.sav'
        # Load the UMAP dimensionality-reduction model trained on the rest of
        # the probabilities. BUGFIX: the original left the file handle open
        # (pickle.load(open(...))); `with` guarantees it is closed.
        with open(fname, 'rb') as fh:
            reducer = pickle.load(fh)
        embed_req = reducer.transform(prob_req)

        # Hoist the repeated df.iloc[index] lookup out of the plot call.
        subset = df.iloc[index]

        # Scatter plot of all embeddings from our dataset.
        fig1 = px.scatter(
            embed, x=0, y=1,
            color=subset['headquarters'], labels={'color': 'states'},
            hover_name=subset['company_name'] + " with industry group: " + subset['industry_groups'])

        # Overlay the user's company as a distinct hexagon marker.
        fig1.add_trace(
            go.Scatter(
                x=embed_req[:, 0],
                y=embed_req[:, 1],
                mode='markers',
                marker_symbol="hexagon2", marker_size=15,
                showlegend=True, name=company_name, hovertext=company_name))
        st.plotly_chart(fig1)

def clean_text(text):
    """Lower-case *text* and strip URLs, HTML tags and punctuation.

    Args:
        text: Any value; coerced to ``str`` first.

    Returns:
        The cleaned, lower-cased string.
    """
    text = str(text).lower()
    # Remove http(s):// and www. URLs.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Remove HTML tags. BUGFIX: the original pattern '<.,*?>+' had a stray
    # comma, so real tags like <b> were never matched; '<.*?>' lazily matches
    # any tag content.
    text = re.sub(r'<.*?>+', '', text)
    # Remove all ASCII punctuation characters.
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)

    return text

def preprocess(name, group, state, states_used, desc):
    """Build a cleaned, stop-word-free, lemmatized text blob for topic modelling.

    The company's own name is removed from the description so it does not
    dominate the topics, and the comma-separated industry groups are
    appended. ``state`` and ``states_used`` are accepted for interface
    parity with the caller but are not used in the text itself.
    """
    description = desc.replace(name, '')
    # Concatenate the comma-separated industry groups into one string.
    categories = "".join(group.split(","))
    combined = description + " " + categories

    stop_words = stopwords.words('english')
    lemmatizer = WordNetLemmatizer()

    cleaned = clean_text(combined)
    kept = [w for w in cleaned.split(' ') if w not in stop_words]
    lemmas = [lemmatizer.lemmatize(w) for w in ' '.join(kept).split(' ')]
    return ' '.join(lemmas)
    
@st.cache(persist=True,suppress_st_warning=True,show_spinner=False)
def load_topic_model(model_path, name, group, state, states_used, desc):
    """Run the saved BERTopic model on the user's company description.

    Loads the model, the companies dataset and its pre-computed UMAP
    embeddings, restricts the rows to the states the user selected, and
    transforms the (preprocessed) user input into topic probabilities.

    Returns:
        Tuple ``(topic, prob_req, select, df, index)`` where ``select``
        holds the embeddings for the dataframe rows listed in ``index``.
    """
    with st.spinner("Creating Topic Models ....."):

        # Saved BERTopic model.
        topic_model = BERTopic.load(model_path)

        # Companies dataset backing the scatter plot.
        df = pd.read_csv('topicmodel/data.csv')

        # Pre-computed UMAP-reduced embeddings, row-aligned with df.
        embeddings = np.load('topicmodel/embed.npy')

        # Clean/lemmatize the user's input the same way the corpus was.
        request = preprocess(name, group, state, states_used, desc)

        # Keep only the rows headquartered in a state the user selected.
        index = []
        for chosen_state in states_used:
            matching = df.index[df['headquarters'].str.contains(chosen_state)]
            index.extend(matching.tolist())
        select = embeddings[index]

        # Topic probabilities for the user's request.
        topic, prob_req = topic_model.transform([request])

        return topic, prob_req, select, df, index

def app():
    """Streamlit page: collect a company's details, run topic modelling, and
    plot it against companies from the selected states.

    NOTE: Streamlit reruns this function top-to-bottom on every interaction;
    widget call order and default values determine the UI state.
    """

    st.title("Competitive Analysis of Companies ")

    # Sidebar toggle that pre-fills the form with a worked example.
    check_examples = st.sidebar.checkbox("Try Examples!")

    st.markdown(read_markdown("userguide.md"))

    # Choices for headquarters / comparison states.
    # NOTE(review): 'Giorgia' looks like a misspelling of 'Georgia', and both
    # 'Washington' and 'Washington DC' appear — presumably these mirror typos
    # in the underlying dataset; confirm against topicmodel/data.csv before
    # removing.
    states= ['Georgia', 'California', 'Texas', 'Tennessee', 'Massachusetts',
        'New York', 'Ohio', 'Delaware', 'Florida', 'Washington',
        'Connecticut', 'Colorado', 'South Carolina', 'New Jersey',
        'Michigan', 'Maryland', 'Pennsylvania', 'Virginia', 'Vermont',
        'Minnesota', 'Illinois', 'North Carolina', 'Montana', 'Kentucky',
        'Oregon', 'Iowa', 'District of Columbia', 'Arizona', 'Wisconsin',
        'Louisiana', 'Idaho', 'Utah', 'Nevada', 'Nebraska', 'New Mexico',
        'Missouri', 'Kansas', 'New Hampshire', 'Wyoming', 'Arkansas',
        'Indiana', 'North Dakota', 'Hawaii', 'Alabama', 'Maine',
        'Rhode Island', 'Mississippi', 'Alaska', 'Oklahoma',
        'Washington DC', 'Giorgia']
    #state= st.selectbox('Select state the company is based in', states)
    #states_used = st.multiselect('Select states you want to analyse', states)

    # Each example row: [name, industry group, state, states to compare, description].
    examples = [['Coursera','Education','California',['California','New York','Ohio'],'We are a social entrepreneurship company that partners with the top universities in the world to offer courses online for anyone to take, for free. We envision a future where the top universities are educating not only thousands of students, but millions. Our technology enables the best professors to teach tens or hundreds of thousands of students']]

    if check_examples:
        # Pre-fill every input with the Coursera example (index=1 selects
        # 'California' in the states list above).
        example = examples[0]
        companyname = st.text_input('Input company name here:', example[0])
        companygrp = st.text_input('Input industry group here:', example[1])
        companydesc = st.text_input("Input company description: (can be found in the company's linkedin page)", example[4])
        state = st.selectbox('Select state the company is based in',states,index = 1)
        states_used = st.multiselect('Select states you want to analyse', states,example[3])
        #model_path = 'topicmodel/my_model.pkl'
        #topic,prob_req,embed,df,index = load_topic_model(model_path,example[0],example[1],example[2],example[3],example[4])
        #visualizer(prob_req,embed,df,index,company_name)

    else:

        # Blank form for the user's own company.
        companyname = st.text_input('Input company name here:', value="")
        companygrp = st.text_input('Input industry group here:', value="")
        companydesc = st.text_input("Input company description: (can be found in the company's linkedin page)", value="")
        state= st.selectbox('Select state the company is based in', states)
        states_used = st.multiselect('Select states you want to analyse', states)

    if(st.button("Analyse Competition")):

        # All free-text fields and at least one comparison state are required.
        if companyname=="" or companydesc=="" or companygrp=="" or states_used==[]:
            st.error("Some fields are empty!")
        else:
            # load_topic_model is cached (st.cache), so repeated clicks with
            # the same inputs skip the expensive BERTopic transform.
            model_path = 'topicmodel/my_model.pkl'
            topic,prob_req,embed,df,index = load_topic_model(model_path, companyname, companygrp, state, states_used, companydesc)
            visualizer(prob_req, embed, df, index, companyname)
                
                
if __name__ == "__main__":
    # Entry point when executed directly (e.g. `streamlit run <file>`).
    # Fixed the non-standard 3-space indentation of the original.
    app()