krishaamer commited on
Commit
b2ce8cc
·
1 Parent(s): c4704e3

Show radar chart for Likert cluters; add tabs

Browse files
Files changed (1) hide show
  1. page_ai.py +99 -10
page_ai.py CHANGED
@@ -5,23 +5,36 @@ from matplotlib.font_manager import FontProperties
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import numpy as np
 
 
8
  from fields.prod_feat_flat_fields import prod_feat_flat_fields
9
  from fields.feature_translations import feature_translations
 
10
 
11
  #@st.cache_data
12
  def show(df):
13
  # Load the Chinese font
14
  chinese_font = FontProperties(fname='notosans.ttf', size=12)
15
  st.title("AI Companion")
16
- st.write("Clustering students based on AI-assistant feature choices")
17
- clusters = perform_kmodes_clustering(df, prod_feat_flat_fields)
18
- st.markdown(
19
- f"<h2 style='text-align: center;'>Feature Preferences</h2>", unsafe_allow_html=True)
20
- show_radar_chart(clusters, font_prop=chinese_font)
21
- plot_feature_preferences(clusters, font_prop=chinese_font)
22
- st.markdown(
23
- f"<h2 style='text-align: center;'>Preferred AI Roles</h2>", unsafe_allow_html=True)
24
- visualize_ai_roles(df, chinese_font)
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def visualize_ai_roles(df, chinese_font):
@@ -35,7 +48,7 @@ def visualize_ai_roles(df, chinese_font):
35
  # Plot the data
36
  plt.figure(figsize=(10, 6))
37
  ai_roles_data.plot(kind='bar', color='skyblue')
38
- plt.title('Desired AI Roles', fontproperties=chinese_font)
39
  plt.xlabel('Roles', fontproperties=chinese_font)
40
  plt.ylabel('Number of Responses', fontproperties=chinese_font)
41
  plt.xticks(rotation=45, ha='right', fontproperties=chinese_font)
@@ -174,3 +187,79 @@ def plot_feature_preferences(clusters, font_prop):
174
 
175
  # Streamlit uses st.pyplot() to display matplotlib charts
176
  st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import numpy as np
8
+ from sklearn.cluster import KMeans
9
+ from sklearn.decomposition import PCA
10
  from fields.prod_feat_flat_fields import prod_feat_flat_fields
11
  from fields.feature_translations import feature_translations
12
+ from fields.likert_flat_fields import likert_flat_fields
13
 
14
  #@st.cache_data
15
  def show(df):
16
  # Load the Chinese font
17
  chinese_font = FontProperties(fname='notosans.ttf', size=12)
18
  st.title("AI Companion")
19
+
20
+ tab1, tab2 = st.tabs(["Likert-Based Clustering", "Feature-Based Clustering"])
21
+
22
+ with tab1:
23
+ st.write("AI-assistant feature choices per Likert-based Personas")
24
+ likert_cluster_and_visualize(df, likert_flat_fields, chinese_font)
25
+
26
+ with tab2:
27
+ st.write("Clustering students based on AI-assistant feature choices")
28
+ clusters = perform_kmodes_clustering(df, prod_feat_flat_fields)
29
+ st.markdown(
30
+ f"<h2 style='text-align: center;'>Feature Preferences (Overall)</h2>", unsafe_allow_html=True)
31
+ show_radar_chart(clusters, font_prop=chinese_font)
32
+ st.markdown(
33
+ f"<h2 style='text-align: center;'>Feature Preferences (By Cluster)</h2>", unsafe_allow_html=True)
34
+ plot_feature_preferences(clusters, font_prop=chinese_font)
35
+ st.markdown(
36
+ f"<h2 style='text-align: center;'>Preferred AI Roles (Overall)</h2>", unsafe_allow_html=True)
37
+ visualize_ai_roles(df, chinese_font)
38
 
39
 
40
  def visualize_ai_roles(df, chinese_font):
 
48
  # Plot the data
49
  plt.figure(figsize=(10, 6))
50
  ai_roles_data.plot(kind='bar', color='skyblue')
51
+ plt.title('Preferred AI Roles', fontproperties=chinese_font)
52
  plt.xlabel('Roles', fontproperties=chinese_font)
53
  plt.ylabel('Number of Responses', fontproperties=chinese_font)
54
  plt.xticks(rotation=45, ha='right', fontproperties=chinese_font)
 
187
 
188
  # Streamlit uses st.pyplot() to display matplotlib charts
189
  st.pyplot(fig)
190
+
191
+ def likert_cluster_and_visualize(df, likert_flat_fields, chinese_font):
192
+ # Clean the DataFrame column names
193
+ df.columns = [col.strip() for col in df.columns]
194
+
195
+ # Also clean the likert_flat_fields if necessary
196
+ likert_flat_fields = [field.strip() for field in likert_flat_fields]
197
+
198
+ # Prepare the likert data, dropping any rows with missing values
199
+ df_likert_data = df[likert_flat_fields].dropna()
200
+
201
+ # Perform k-means clustering
202
+ kmeans = KMeans(n_clusters=3, n_init=10, random_state=42).fit(df_likert_data)
203
+ df_likert_data['Cluster'] = kmeans.labels_
204
+
205
+ # Concatenate the cluster labels with the original data
206
+ df_clustered = pd.concat([df, df_likert_data['Cluster']], axis=1)
207
+
208
+ # Aggregate the product preference data for each cluster
209
+ cluster_preferences = []
210
+ for i in range(3):
211
+ cluster_data = df_clustered[df_clustered['Cluster'] == i]
212
+ cluster_preferences.append(cluster_data[prod_feat_flat_fields].mean())
213
+
214
+ # Radar Chart Plotting
215
+ df_dict = {
216
+ 'Eco-Friendly': cluster_preferences[0],
217
+ 'Moderate': cluster_preferences[1],
218
+ 'Frugal': cluster_preferences[2]
219
+ }
220
+
221
+ feature_translations_dict = dict(zip(prod_feat_flat_fields, feature_translations))
222
+ persona_averages = [df_dict[key].tolist() for key in df_dict]
223
+
224
+ # Append the first value at the end of each list for the radar chart
225
+ for averages in persona_averages:
226
+ averages += averages[:1]
227
+
228
+ # Prepare the English labels for plotting
229
+ english_feature_labels = list(feature_translations)
230
+ english_feature_labels += [english_feature_labels[0]] # Repeat the first label to close the loop
231
+
232
+ # Number of variables we're plotting
233
+ num_vars = len(english_feature_labels)
234
+
235
+ # Split the circle into even parts and save the angles
236
+ angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
237
+ angles += angles[:1] # Complete the loop
238
+
239
+ # Set up the font properties for using a custom font
240
+ fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
241
+ fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
242
+
243
+ # Draw one axe per variable and add labels
244
+ plt.xticks(angles[:-1], english_feature_labels, color='grey', size=12, fontproperties=chinese_font)
245
+
246
+ # Draw ylabels
247
+ ax.set_rlabel_position(0)
248
+ plt.yticks([0.2, 0.4, 0.6, 0.8, 1], ["0.2", "0.4", "0.6", "0.8", "1"], color="grey", size=7)
249
+ plt.ylim(0, 1)
250
+
251
+ # Plot data and fill with color
252
+ for label, data in zip(df_dict.keys(), persona_averages):
253
+ data += data[:1] # Complete the loop
254
+ ax.plot(angles, data, label=label, linewidth=1, linestyle='solid')
255
+ ax.fill(angles, data, alpha=0.25)
256
+
257
+ # Add legend
258
+ plt.legend(title='Personas')
259
+ plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
260
+
261
+ # Add a title
262
+ plt.title('Product Feature Preferences by Persona', size=20, color='grey', y=1.1, fontproperties=chinese_font)
263
+
264
+ # Display the radar chart
265
+ st.pyplot(fig)