Spaces:
Sleeping
Sleeping
Commit ·
b2ce8cc
1
Parent(s): c4704e3
Show radar chart for Likert cluters; add tabs
Browse files- page_ai.py +99 -10
page_ai.py
CHANGED
|
@@ -5,23 +5,36 @@ from matplotlib.font_manager import FontProperties
|
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import seaborn as sns
|
| 7 |
import numpy as np
|
|
|
|
|
|
|
| 8 |
from fields.prod_feat_flat_fields import prod_feat_flat_fields
|
| 9 |
from fields.feature_translations import feature_translations
|
|
|
|
| 10 |
|
| 11 |
#@st.cache_data
|
| 12 |
def show(df):
|
| 13 |
# Load the Chinese font
|
| 14 |
chinese_font = FontProperties(fname='notosans.ttf', size=12)
|
| 15 |
st.title("AI Companion")
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def visualize_ai_roles(df, chinese_font):
|
|
@@ -35,7 +48,7 @@ def visualize_ai_roles(df, chinese_font):
|
|
| 35 |
# Plot the data
|
| 36 |
plt.figure(figsize=(10, 6))
|
| 37 |
ai_roles_data.plot(kind='bar', color='skyblue')
|
| 38 |
-
plt.title('
|
| 39 |
plt.xlabel('Roles', fontproperties=chinese_font)
|
| 40 |
plt.ylabel('Number of Responses', fontproperties=chinese_font)
|
| 41 |
plt.xticks(rotation=45, ha='right', fontproperties=chinese_font)
|
|
@@ -174,3 +187,79 @@ def plot_feature_preferences(clusters, font_prop):
|
|
| 174 |
|
| 175 |
# Streamlit uses st.pyplot() to display matplotlib charts
|
| 176 |
st.pyplot(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import seaborn as sns
|
| 7 |
import numpy as np
|
| 8 |
+
from sklearn.cluster import KMeans
|
| 9 |
+
from sklearn.decomposition import PCA
|
| 10 |
from fields.prod_feat_flat_fields import prod_feat_flat_fields
|
| 11 |
from fields.feature_translations import feature_translations
|
| 12 |
+
from fields.likert_flat_fields import likert_flat_fields
|
| 13 |
|
| 14 |
#@st.cache_data
|
| 15 |
def show(df):
|
| 16 |
# Load the Chinese font
|
| 17 |
chinese_font = FontProperties(fname='notosans.ttf', size=12)
|
| 18 |
st.title("AI Companion")
|
| 19 |
+
|
| 20 |
+
tab1, tab2 = st.tabs(["Likert-Based Clustering", "Feature-Based Clustering"])
|
| 21 |
+
|
| 22 |
+
with tab1:
|
| 23 |
+
st.write("AI-assistant feature choices per Likert-based Personas")
|
| 24 |
+
likert_cluster_and_visualize(df, likert_flat_fields, chinese_font)
|
| 25 |
+
|
| 26 |
+
with tab2:
|
| 27 |
+
st.write("Clustering students based on AI-assistant feature choices")
|
| 28 |
+
clusters = perform_kmodes_clustering(df, prod_feat_flat_fields)
|
| 29 |
+
st.markdown(
|
| 30 |
+
f"<h2 style='text-align: center;'>Feature Preferences (Overall)</h2>", unsafe_allow_html=True)
|
| 31 |
+
show_radar_chart(clusters, font_prop=chinese_font)
|
| 32 |
+
st.markdown(
|
| 33 |
+
f"<h2 style='text-align: center;'>Feature Preferences (By Cluster)</h2>", unsafe_allow_html=True)
|
| 34 |
+
plot_feature_preferences(clusters, font_prop=chinese_font)
|
| 35 |
+
st.markdown(
|
| 36 |
+
f"<h2 style='text-align: center;'>Preferred AI Roles (Overall)</h2>", unsafe_allow_html=True)
|
| 37 |
+
visualize_ai_roles(df, chinese_font)
|
| 38 |
|
| 39 |
|
| 40 |
def visualize_ai_roles(df, chinese_font):
|
|
|
|
| 48 |
# Plot the data
|
| 49 |
plt.figure(figsize=(10, 6))
|
| 50 |
ai_roles_data.plot(kind='bar', color='skyblue')
|
| 51 |
+
plt.title('Preferred AI Roles', fontproperties=chinese_font)
|
| 52 |
plt.xlabel('Roles', fontproperties=chinese_font)
|
| 53 |
plt.ylabel('Number of Responses', fontproperties=chinese_font)
|
| 54 |
plt.xticks(rotation=45, ha='right', fontproperties=chinese_font)
|
|
|
|
| 187 |
|
| 188 |
# Streamlit uses st.pyplot() to display matplotlib charts
|
| 189 |
st.pyplot(fig)
|
| 190 |
+
|
| 191 |
+
def likert_cluster_and_visualize(df, likert_flat_fields, chinese_font):
|
| 192 |
+
# Clean the DataFrame column names
|
| 193 |
+
df.columns = [col.strip() for col in df.columns]
|
| 194 |
+
|
| 195 |
+
# Also clean the likert_flat_fields if necessary
|
| 196 |
+
likert_flat_fields = [field.strip() for field in likert_flat_fields]
|
| 197 |
+
|
| 198 |
+
# Prepare the likert data, dropping any rows with missing values
|
| 199 |
+
df_likert_data = df[likert_flat_fields].dropna()
|
| 200 |
+
|
| 201 |
+
# Perform k-means clustering
|
| 202 |
+
kmeans = KMeans(n_clusters=3, n_init=10, random_state=42).fit(df_likert_data)
|
| 203 |
+
df_likert_data['Cluster'] = kmeans.labels_
|
| 204 |
+
|
| 205 |
+
# Concatenate the cluster labels with the original data
|
| 206 |
+
df_clustered = pd.concat([df, df_likert_data['Cluster']], axis=1)
|
| 207 |
+
|
| 208 |
+
# Aggregate the product preference data for each cluster
|
| 209 |
+
cluster_preferences = []
|
| 210 |
+
for i in range(3):
|
| 211 |
+
cluster_data = df_clustered[df_clustered['Cluster'] == i]
|
| 212 |
+
cluster_preferences.append(cluster_data[prod_feat_flat_fields].mean())
|
| 213 |
+
|
| 214 |
+
# Radar Chart Plotting
|
| 215 |
+
df_dict = {
|
| 216 |
+
'Eco-Friendly': cluster_preferences[0],
|
| 217 |
+
'Moderate': cluster_preferences[1],
|
| 218 |
+
'Frugal': cluster_preferences[2]
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
feature_translations_dict = dict(zip(prod_feat_flat_fields, feature_translations))
|
| 222 |
+
persona_averages = [df_dict[key].tolist() for key in df_dict]
|
| 223 |
+
|
| 224 |
+
# Append the first value at the end of each list for the radar chart
|
| 225 |
+
for averages in persona_averages:
|
| 226 |
+
averages += averages[:1]
|
| 227 |
+
|
| 228 |
+
# Prepare the English labels for plotting
|
| 229 |
+
english_feature_labels = list(feature_translations)
|
| 230 |
+
english_feature_labels += [english_feature_labels[0]] # Repeat the first label to close the loop
|
| 231 |
+
|
| 232 |
+
# Number of variables we're plotting
|
| 233 |
+
num_vars = len(english_feature_labels)
|
| 234 |
+
|
| 235 |
+
# Split the circle into even parts and save the angles
|
| 236 |
+
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
| 237 |
+
angles += angles[:1] # Complete the loop
|
| 238 |
+
|
| 239 |
+
# Set up the font properties for using a custom font
|
| 240 |
+
fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
|
| 241 |
+
fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
|
| 242 |
+
|
| 243 |
+
# Draw one axe per variable and add labels
|
| 244 |
+
plt.xticks(angles[:-1], english_feature_labels, color='grey', size=12, fontproperties=chinese_font)
|
| 245 |
+
|
| 246 |
+
# Draw ylabels
|
| 247 |
+
ax.set_rlabel_position(0)
|
| 248 |
+
plt.yticks([0.2, 0.4, 0.6, 0.8, 1], ["0.2", "0.4", "0.6", "0.8", "1"], color="grey", size=7)
|
| 249 |
+
plt.ylim(0, 1)
|
| 250 |
+
|
| 251 |
+
# Plot data and fill with color
|
| 252 |
+
for label, data in zip(df_dict.keys(), persona_averages):
|
| 253 |
+
data += data[:1] # Complete the loop
|
| 254 |
+
ax.plot(angles, data, label=label, linewidth=1, linestyle='solid')
|
| 255 |
+
ax.fill(angles, data, alpha=0.25)
|
| 256 |
+
|
| 257 |
+
# Add legend
|
| 258 |
+
plt.legend(title='Personas')
|
| 259 |
+
plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
|
| 260 |
+
|
| 261 |
+
# Add a title
|
| 262 |
+
plt.title('Product Feature Preferences by Persona', size=20, color='grey', y=1.1, fontproperties=chinese_font)
|
| 263 |
+
|
| 264 |
+
# Display the radar chart
|
| 265 |
+
st.pyplot(fig)
|