suvradeepp commited on
Commit
a5f0433
·
verified ·
1 Parent(s): 7e315ff

Create voting_classifier_viz.py

Browse files
Files changed (1) hide show
  1. voting_classifier_viz.py +74 -0
voting_classifier_viz.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ from sthelper import StHelper
4
+ import data_helper
5
+
6
+ # Import all datasets
7
+ concentric, linear, outlier, spiral, ushape, xor = data_helper.load_dataset()
8
+
9
+ # Configure matplotlib styling
10
+ plt.style.use('seaborn-v0_8-bright')
11
+
12
+ # Dataset selection dropdown
13
+ st.sidebar.markdown("# Voting Classifier")
14
+ dataset = st.sidebar.selectbox(
15
+ "Dataset",
16
+ ("U-Shaped", "Linearly Separable", "Outlier", "Two Spirals", "Concentric Circles", "XOR")
17
+ )
18
+
19
+ # Estimator multi-select
20
+ estimators = st.sidebar.multiselect(
21
+ 'Estimators',
22
+ [
23
+ 'KNN',
24
+ 'Logistic Regression',
25
+ 'Gaussian Naive Bayes',
26
+ 'SVM',
27
+ 'Random Forest'
28
+ ]
29
+ )
30
+
31
+ # Voting type radio button
32
+ voting_type = st.sidebar.radio(
33
+ "Voting Type",
34
+ (
35
+ 'hard',
36
+ 'soft',
37
+ )
38
+ )
39
+
40
+ st.header(dataset)
41
+ fig, ax = plt.subplots()
42
+
43
+ # Plot initial graph
44
+ df = data_helper.load_initial_graph(dataset, ax)
45
+ orig = st.pyplot(fig)
46
+
47
+ # Extract X and Y
48
+ X = df.iloc[:, :2].values
49
+ y = df.iloc[:, -1].values
50
+
51
+ # Create sthelper object
52
+ sthelper = StHelper(X, y)
53
+
54
+ # On button click
55
+ if st.sidebar.button("Run Algorithm"):
56
+ algos = sthelper.create_base_estimators(estimators, voting_type)
57
+ voting_clf, voting_clf_accuracy = sthelper.train_voting_classifier(algos, voting_type)
58
+ sthelper.draw_main_graph(voting_clf, ax)
59
+ orig.pyplot(fig)
60
+ figs = sthelper.plot_other_graphs(algos)
61
+
62
+ # Plot accuracies
63
+ st.sidebar.header("Classification Metrics")
64
+ st.sidebar.text("Voting Classifier accuracy: " + str(voting_clf_accuracy))
65
+ accuracies = sthelper.calculate_base_model_accuracy(algos)
66
+ for i in range(len(accuracies)):
67
+ st.sidebar.text("Accuracy for Model " + str(i + 1) + " - " + str(accuracies[i]))
68
+
69
+ counter = 0
70
+ for i in st.columns(len(figs)):
71
+ with i:
72
+ st.pyplot(figs[counter])
73
+ st.text(counter)
74
+ counter += 1