Spaces:
Runtime error
Runtime error
| from json import load | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import re | |
| import string | |
| from nltk.stem import WordNetLemmatizer | |
| import umap | |
| import plotly.graph_objects as go | |
| from plotly import tools | |
| import plotly.offline as py | |
| import plotly.express as px | |
| from nltk.corpus import stopwords | |
| import nltk | |
| nltk.download('stopwords') | |
| nltk.download('wordnet') | |
| from bertopic import BERTopic | |
| import pickle | |
| import os | |
def read_markdown(path, parent='about/'):
    """Return the full text of a markdown file.

    Args:
        path: File name (or relative path) of the markdown file.
        parent: Directory that ``path`` is joined onto. Defaults to
            ``'about/'``.

    Returns:
        The file's contents as a single string.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale default of the machine the app happens to run on.
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()
def visualizer(prob_req, embed, df, index, company_name):
    """Plot the selected dataset points and overlay the user's company.

    Args:
        prob_req: Topic probabilities for the user's request; handed to the
            saved reducer's ``transform``.
        embed: Already-reduced embeddings of the selected dataset rows;
            columns 0 and 1 are used as the x and y coordinates.
        df: Dataset frame; the ``headquarters``, ``company_name`` and
            ``industry_groups`` columns are read.
        index: Row positions (used with ``df.iloc``) of the selected rows.
        company_name: Legend name and hover text for the user's marker.

    Returns:
        None. The figure is rendered with ``st.plotly_chart``.
    """
    with st.spinner("Visualizing the results !!!"):
        fname = 'topicmodel/saving_example.sav'
        # Load the UMAP dimensionality-reduction model trained on the rest of
        # the probabilities. The context manager closes the file handle.
        # NOTE: pickle runs arbitrary code on load -- this file must only
        # ever be a trusted artifact shipped with the app.
        with open(fname, 'rb') as model_file:
            reducer = pickle.load(model_file)
        embed_req = reducer.transform(prob_req)

        # Look the selected rows up once instead of once per column.
        selected = df.iloc[index]
        hover_labels = (
            selected['company_name']
            + " with industry group: "
            + selected['industry_groups']
        )

        # Scatter plot for all selected embeddings from the dataset.
        fig1 = px.scatter(
            embed, x=0, y=1,
            color=selected['headquarters'],
            labels={'color': 'states'},
            hover_name=hover_labels,
        )
        # Add the point for the user's request and display the figure.
        fig1.add_trace(
            go.Scatter(
                x=embed_req[:, 0],
                y=embed_req[:, 1],
                mode='markers',
                marker_symbol="hexagon2", marker_size=15,
                showlegend=True, name=company_name, hovertext=company_name,
            )
        )
        st.plotly_chart(fig1)
def clean_text(text):
    """Normalize free text before it is fed to the topic model.

    The input is coerced to ``str`` and lower-cased; then URLs, HTML-style
    tags and ASCII punctuation are removed, in that order.

    Args:
        text: Any object; it is converted with ``str()`` first.

    Returns:
        The cleaned, lower-cased string. Whitespace is left untouched, so a
        removed fragment may leave consecutive spaces behind.
    """
    text = str(text).lower()
    # Drop http(s):// links and bare www. links up to the next whitespace.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Strip HTML-style tags. The earlier pattern '<.,*?>+' allowed only one
    # character (plus optional commas) between the brackets, so tags such as
    # <br> survived and leaked into the text as "br".
    text = re.sub(r'<.*?>+', '', text)
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
    return text
def preprocess(name, group, state, states_used, desc):
    """Turn a company's description and industry group into model input text.

    Args:
        name: Company name; every occurrence is removed from ``desc``.
        group: Industry-group string; its commas are dropped before it is
            appended to the description.
        state: Unused here; kept so the signature stays unchanged.
        states_used: Unused here; kept so the signature stays unchanged.
        desc: Free-text company description.

    Returns:
        The cleaned text with English stop words removed and the remaining
        space-separated tokens lemmatized.
    """
    desc = desc.replace(name, '')
    # Equivalent to joining the comma-separated pieces with no separator.
    cat = group.replace(",", "")
    cleaned = desc + " " + cat

    # A set makes each membership test O(1) instead of scanning the list of
    # stop words once per token.
    stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()

    text = clean_text(cleaned)
    return ' '.join(
        lemmatizer.lemmatize(w) for w in text.split(' ') if w not in stop_words
    )
def load_topic_model(model_path, name, group, state, states_used, desc):
    """Run the topic model on the user's company and pick comparison rows.

    Args:
        model_path: Path of the saved BERTopic model.
        name: Company name, forwarded to ``preprocess``.
        group: Industry group, forwarded to ``preprocess``.
        state: Company's state, forwarded to ``preprocess``.
        states_used: States to compare against; a dataset row is selected
            when its ``headquarters`` value contains one of these strings.
        desc: Company description, forwarded to ``preprocess``.

    Returns:
        A tuple ``(topic, prob_req, select, df, index)``: the two values
        returned by ``model.transform`` for the request, the stored
        embeddings of the selected rows, the full dataset frame, and the
        list of selected row indices.
    """
    with st.spinner("Creating Topic Models ....."):
        # Load BERTopic.
        model = BERTopic.load(model_path)
        # Dataset used for creating the scatter plot.
        data_path = 'topicmodel/data.csv'
        df = pd.read_csv(data_path)
        # Embeddings already reduced by UMAP for the points to be plotted.
        embeddings_path = 'topicmodel/embed.npy'
        embeddings = np.load(embeddings_path)
        # Preprocess user inputs.
        request = preprocess(name, group, state, states_used, desc)

        # Only select rows from the states the user wants to compare.
        headquarters = df['headquarters']
        index = []
        for state_used in states_used:
            # regex=False: state names are literal text, not patterns.
            # na=False: rows with a missing headquarters count as "no match"
            # instead of putting NaN into the mask, which breaks indexing.
            matches = headquarters.str.contains(state_used, regex=False, na=False)
            index.extend(df.index[matches].tolist())
        # A row can match more than one selected state (e.g. 'Washington'
        # and 'Washington DC'); keep the first occurrence only so it is not
        # plotted twice.
        index = list(dict.fromkeys(index))
        select = embeddings[index]

        # Use BERTopic to get the topic probabilities for the request.
        topic, prob_req = model.transform([request])
    return topic, prob_req, select, df, index
def app():
    """Render the competitive-analysis page.

    Shows the title, the user guide and the input widgets (pre-filled from
    the bundled example when the sidebar checkbox is ticked). When the
    "Analyse Competition" button is pressed and every field is filled in,
    the topic model is run and the result is plotted.
    """
    st.title("Competitive Analysis of Companies ")
    check_examples = st.sidebar.checkbox("Try Examples!")
    st.markdown(read_markdown("userguide.md"))

    states = ['Georgia', 'California', 'Texas', 'Tennessee', 'Massachusetts',
              'New York', 'Ohio', 'Delaware', 'Florida', 'Washington',
              'Connecticut', 'Colorado', 'South Carolina', 'New Jersey',
              'Michigan', 'Maryland', 'Pennsylvania', 'Virginia', 'Vermont',
              'Minnesota', 'Illinois', 'North Carolina', 'Montana', 'Kentucky',
              'Oregon', 'Iowa', 'District of Columbia', 'Arizona', 'Wisconsin',
              'Louisiana', 'Idaho', 'Utah', 'Nevada', 'Nebraska', 'New Mexico',
              'Missouri', 'Kansas', 'New Hampshire', 'Wyoming', 'Arkansas',
              'Indiana', 'North Dakota', 'Hawaii', 'Alabama', 'Maine',
              'Rhode Island', 'Mississippi', 'Alaska', 'Oklahoma',
              'Washington DC', 'Giorgia']
    examples = [['Coursera', 'Education', 'California', ['California', 'New York', 'Ohio'], 'We are a social entrepreneurship company that partners with the top universities in the world to offer courses online for anyone to take, for free. We envision a future where the top universities are educating not only thousands of students, but millions. Our technology enables the best professors to teach tens or hundreds of thousands of students']]

    # Work out the widget defaults first, then build each widget once.
    if check_examples:
        sample = examples[0]
        name_default = sample[0]
        group_default = sample[1]
        desc_default = sample[4]
        state_position = 1
        compare_default = sample[3]
    else:
        name_default = ""
        group_default = ""
        desc_default = ""
        state_position = 0
        compare_default = None

    companyname = st.text_input('Input company name here:', name_default)
    companygrp = st.text_input('Input industry group here:', group_default)
    companydesc = st.text_input(
        "Input company description: (can be found in the company's linkedin page)",
        desc_default,
    )
    state = st.selectbox(
        'Select state the company is based in', states, index=state_position
    )
    states_used = st.multiselect(
        'Select states you want to analyse', states, compare_default
    )

    # Nothing more to do until the user presses the button.
    if not st.button("Analyse Competition"):
        return

    every_field_filled = bool(
        companyname and companydesc and companygrp and states_used
    )
    if not every_field_filled:
        st.error("Some fields are empty!")
        return

    model_path = 'topicmodel/my_model.pkl'
    topic, prob_req, embed, df, index = load_topic_model(
        model_path, companyname, companygrp, state, states_used, companydesc
    )
    visualizer(prob_req, embed, df, index, companyname)


if __name__ == "__main__":
    app()