chiichann commited on
Commit
53c2b9d
·
verified ·
1 Parent(s): 086ca05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -72
app.py CHANGED
@@ -6,24 +6,16 @@ from sklearn.preprocessing import StandardScaler
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
  import requests
9
- from io import StringIO
10
 
11
  # App title
12
  st.title("🛍️ Customer Segmentation Tool")
13
 
14
- # 🎯 Streamlit Tabs
15
  tab1, tab2, tab3, tab4 = st.tabs(["📖 About", "📊 Dataset Overview", "🧑‍🤝‍🧑 Customer Segmentation", "📥 Download Dataset"])
16
 
17
  # About Tab
18
  with tab1:
19
  st.write("""
20
- This app uses unsupervised learning techniques to segment customers based on their purchasing behavior.
21
- The dataset is uploaded by the user, containing online retail data.
22
- ### How It Works:
23
- - **Step 1**: Upload customer transaction data, including details like Quantity, UnitPrice, and CustomerID.
24
- - **Step 2**: Process the data by calculating the total spent and aggregating the information by customer.
25
- - **Step 3**: Apply **K-Means Clustering** to segment the customers into distinct groups.
26
- - **Step 4**: Visualize the customer segments with a scatter plot, and optionally download the segmented data.
27
  """)
28
 
29
  # File uploader in the Dataset Tab
@@ -32,70 +24,63 @@ with tab2:
32
 
33
  if uploaded_file is not None:
34
  try:
35
- # Check file type
36
  if uploaded_file.name.endswith('.csv'):
37
- # Read CSV file with error handling for malformed lines
38
  df = pd.read_csv(uploaded_file, encoding='ISO-8859-1', on_bad_lines='skip')
39
  elif uploaded_file.name.endswith('.xlsx'):
40
- # Read Excel file
41
  df = pd.read_excel(uploaded_file)
42
  else:
43
  st.error("Unsupported file format. Please upload a CSV or Excel file.")
44
  st.stop()
45
-
46
  st.write("### Dataset Overview")
47
  st.write(df.head())
48
  except Exception as e:
49
  st.error(f"Error loading dataset: {e}")
50
  st.stop()
51
 
52
- # Automatically detect possible columns
53
  st.write("### Columns detected in your dataset:")
54
  st.write(df.columns.tolist())
55
 
56
- # Allow the user to map columns
57
- customer_col = st.selectbox("Select Customer Column", df.columns.tolist(), index=df.columns.tolist().index("CustomerID") if "CustomerID" in df.columns else 0)
58
- quantity_col = st.selectbox("Select Quantity Column", df.columns.tolist(), index=df.columns.tolist().index("Quantity") if "Quantity" in df.columns else 0)
59
- unit_price_col = st.selectbox("Select Unit Price Column", df.columns.tolist(), index=df.columns.tolist().index("UnitPrice") if "UnitPrice" in df.columns else 0)
60
 
61
- # Check if the selected columns exist
62
  if customer_col not in df.columns or quantity_col not in df.columns or unit_price_col not in df.columns:
63
  st.error("One or more selected columns do not exist in the dataset. Please select valid columns.")
64
  st.stop()
65
 
66
- # Preprocess data
67
- df = df.dropna(subset=[customer_col]) # Remove rows without CustomerID
68
- df["TotalSpent"] = pd.to_numeric(df[quantity_col], errors='coerce') * pd.to_numeric(df[unit_price_col], errors='coerce') # Ensure numeric type
69
-
70
- # Ensure that TotalSpent column is not NaN
71
  df = df.dropna(subset=["TotalSpent"])
72
 
73
- # Aggregate data by Customer
74
  customer_data = df.groupby(customer_col).agg({
75
  "TotalSpent": "sum",
76
  quantity_col: "sum",
77
  unit_price_col: "mean"
78
- }).rename(columns={quantity_col: "NumTransactions", unit_price_col: "AvgUnitPrice"})
79
-
80
- # Debug: Check if 'NumTransactions' exists in the DataFrame
81
- st.write("### Available columns in the aggregated customer data:")
82
- st.write(customer_data.columns.tolist())
 
 
 
 
 
 
 
 
83
 
84
- # Standardize the data
85
  scaler = StandardScaler()
86
  customer_scaled = pd.DataFrame(scaler.fit_transform(customer_data), columns=customer_data.columns, index=customer_data.index)
87
 
88
  # Customer Segmentation Tab
89
  with tab3:
90
  if uploaded_file is not None:
91
- # User selects the number of clusters
92
  num_clusters = st.slider("Select Number of Clusters", min_value=2, max_value=10, value=3)
93
-
94
- # Apply K-Means clustering
95
  model = KMeans(n_clusters=num_clusters, random_state=42)
96
  customer_data["Cluster"] = model.fit_predict(customer_scaled)
97
 
98
- # Visualize the clusters
99
  st.write("### Clusters Visualization")
100
  fig, ax = plt.subplots()
101
  scatter = ax.scatter(customer_data["TotalSpent"], customer_data["NumTransactions"], c=customer_data["Cluster"], cmap='viridis')
@@ -105,11 +90,9 @@ with tab3:
105
  plt.colorbar(scatter, label="Cluster")
106
  st.pyplot(fig)
107
 
108
- # Show the segmented customer data
109
  st.write("### Customer Segments Data")
110
  st.write(customer_data.head())
111
 
112
- # Option to download the segmented data
113
  csv = customer_data.to_csv(index=True)
114
  st.download_button(
115
  label="Download Segmented Customer Data",
@@ -119,38 +102,3 @@ with tab3:
119
  )
120
  else:
121
  st.write("Please upload a dataset to start.")
122
-
123
- # Download Dataset Tab
124
- with tab4:
125
- st.write("""
126
- You can download the sample 'Online Retail' dataset to get started with customer segmentation tasks.
127
- Click the button below to download the dataset in CSV format.
128
- """)
129
-
130
- # Direct Google Drive link to the 'Online Retail' dataset (for direct download)
131
- dataset_url_online_retail = "https://drive.google.com/uc?id=1djBqO2sdHfy9DGZQXZu2Er8LUUXtp9Kr&export=download"
132
-
133
- # Direct Google Drive link to the new dataset (for direct download)
134
- dataset_url_new_file = "https://drive.google.com/uc?id=1PbGJSdcyDInsu-9Ua4iHzQh-YpVk_RqT&export=download"
135
-
136
- # Download the file from the URLs
137
- response_online_retail = requests.get(dataset_url_online_retail)
138
- file_data_online_retail = response_online_retail.text # Get the content as text
139
-
140
- response_new_file = requests.get(dataset_url_new_file)
141
- file_data_new_file = response_new_file.text # Get the content as text
142
-
143
- # Convert the CSV data into a CSV download for Streamlit
144
- st.download_button(
145
- label="Download Online Retail Dataset",
146
- data=file_data_online_retail,
147
- file_name="Online_Retail.csv",
148
- mime="text/csv"
149
- )
150
-
151
- st.download_button(
152
- label="Download New Dataset",
153
- data=file_data_new_file,
154
- file_name="New_Dataset.csv",
155
- mime="text/csv"
156
- )
 
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
  import requests
 
9
 
10
  # App title
11
  st.title("🛍️ Customer Segmentation Tool")
12
 
 
13
  tab1, tab2, tab3, tab4 = st.tabs(["📖 About", "📊 Dataset Overview", "🧑‍🤝‍🧑 Customer Segmentation", "📥 Download Dataset"])
14
 
15
  # About Tab
16
  with tab1:
17
  st.write("""
18
+ This app uses unsupervised learning techniques to segment customers based on their purchasing behavior.
 
 
 
 
 
 
19
  """)
20
 
21
  # File uploader in the Dataset Tab
 
24
 
25
  if uploaded_file is not None:
26
  try:
 
27
  if uploaded_file.name.endswith('.csv'):
 
28
  df = pd.read_csv(uploaded_file, encoding='ISO-8859-1', on_bad_lines='skip')
29
  elif uploaded_file.name.endswith('.xlsx'):
 
30
  df = pd.read_excel(uploaded_file)
31
  else:
32
  st.error("Unsupported file format. Please upload a CSV or Excel file.")
33
  st.stop()
34
+
35
  st.write("### Dataset Overview")
36
  st.write(df.head())
37
  except Exception as e:
38
  st.error(f"Error loading dataset: {e}")
39
  st.stop()
40
 
 
41
  st.write("### Columns detected in your dataset:")
42
  st.write(df.columns.tolist())
43
 
44
+ customer_col = st.selectbox("Select Customer Column", df.columns.tolist())
45
+ quantity_col = st.selectbox("Select Quantity Column", df.columns.tolist())
46
+ unit_price_col = st.selectbox("Select Unit Price Column", df.columns.tolist())
 
47
 
 
48
  if customer_col not in df.columns or quantity_col not in df.columns or unit_price_col not in df.columns:
49
  st.error("One or more selected columns do not exist in the dataset. Please select valid columns.")
50
  st.stop()
51
 
52
+ df = df.dropna(subset=[customer_col])
53
+ df["TotalSpent"] = pd.to_numeric(df[quantity_col], errors='coerce') * pd.to_numeric(df[unit_price_col], errors='coerce')
 
 
 
54
  df = df.dropna(subset=["TotalSpent"])
55
 
 
56
  customer_data = df.groupby(customer_col).agg({
57
  "TotalSpent": "sum",
58
  quantity_col: "sum",
59
  unit_price_col: "mean"
60
+ })
61
+
62
+ # Debugging: Check column names before renaming
63
+ st.write("### Columns before renaming:", customer_data.columns.tolist())
64
+
65
+ customer_data = customer_data.rename(columns={quantity_col: "NumTransactions", unit_price_col: "AvgUnitPrice"})
66
+
67
+ # Debugging: Check column names after renaming
68
+ st.write("### Columns after renaming:", customer_data.columns.tolist())
69
+
70
+ if "NumTransactions" not in customer_data.columns:
71
+ st.error("Error: 'NumTransactions' column is missing after processing. Please check column mapping.")
72
+ st.stop()
73
 
 
74
  scaler = StandardScaler()
75
  customer_scaled = pd.DataFrame(scaler.fit_transform(customer_data), columns=customer_data.columns, index=customer_data.index)
76
 
77
  # Customer Segmentation Tab
78
  with tab3:
79
  if uploaded_file is not None:
 
80
  num_clusters = st.slider("Select Number of Clusters", min_value=2, max_value=10, value=3)
 
 
81
  model = KMeans(n_clusters=num_clusters, random_state=42)
82
  customer_data["Cluster"] = model.fit_predict(customer_scaled)
83
 
 
84
  st.write("### Clusters Visualization")
85
  fig, ax = plt.subplots()
86
  scatter = ax.scatter(customer_data["TotalSpent"], customer_data["NumTransactions"], c=customer_data["Cluster"], cmap='viridis')
 
90
  plt.colorbar(scatter, label="Cluster")
91
  st.pyplot(fig)
92
 
 
93
  st.write("### Customer Segments Data")
94
  st.write(customer_data.head())
95
 
 
96
  csv = customer_data.to_csv(index=True)
97
  st.download_button(
98
  label="Download Segmented Customer Data",
 
102
  )
103
  else:
104
  st.write("Please upload a dataset to start.")