chiichann commited on
Commit
c68d93e
·
verified ·
1 Parent(s): 7766924

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -109
app.py CHANGED
@@ -5,97 +5,83 @@ from sklearn.cluster import KMeans
5
  from sklearn.preprocessing import StandardScaler
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
- import requests
9
- from io import StringIO
10
 
11
  # App title
12
  st.title("🛍️ Customer Segmentation Tool")
13
 
14
- # 🎯 Streamlit Tabs
15
  tab1, tab2, tab3, tab4 = st.tabs(["📖 About", "📊 Dataset Overview", "🧑‍🤝‍🧑 Customer Segmentation", "📥 Download Dataset"])
16
 
17
  # About Tab
18
  with tab1:
19
  st.write("""
20
- This app uses unsupervised learning techniques to segment customers based on their purchasing behavior.
21
- The dataset is uploaded by the user, containing online retail data.
22
- ### How It Works:
23
- - **Step 1**: Upload customer transaction data, including details like `Quantity`, `UnitPrice`, and `CustomerID`.
24
- - **Step 2**: Process the data by calculating the total spent and aggregating the information by customer.
25
- - **Step 3**: Apply **K-Means Clustering** to segment the customers into distinct groups.
26
- - **Step 4**: Visualize the customer segments with a scatter plot, and optionally download the segmented data.
27
  """)
28
 
29
  # File uploader in the Dataset Tab
30
  with tab2:
31
- uploaded_file = st.file_uploader("Upload Your Dataset", type=["csv", "xlsx"])
32
-
33
- if uploaded_file is not None:
34
- try:
35
- # Check file type
36
- if uploaded_file.name.endswith('.csv'):
37
- # Read CSV file with error handling for malformed lines
38
- df = pd.read_csv(uploaded_file, encoding='ISO-8859-1', on_bad_lines='skip')
39
- elif uploaded_file.name.endswith('.xlsx'):
40
- # Read Excel file
41
- df = pd.read_excel(uploaded_file)
42
- else:
43
- st.error("Unsupported file format. Please upload a CSV or Excel file.")
44
- st.stop()
45
-
46
- st.write("### Dataset Overview")
47
- st.write(df.head())
48
- except Exception as e:
49
- st.error(f"Error loading dataset: {e}")
50
- st.stop()
51
-
52
- # Automatically detect possible columns
53
- st.write("### Columns detected in your dataset:")
54
- st.write(df.columns.tolist())
55
-
56
- # Allow the user to map columns
57
- customer_col = st.selectbox("Select Customer Column", df.columns.tolist(), index=df.columns.tolist().index("CustomerID") if "CustomerID" in df.columns else 0)
58
- quantity_col = st.selectbox("Select Quantity Column", df.columns.tolist(), index=df.columns.tolist().index("Quantity") if "Quantity" in df.columns else 0)
59
- unit_price_col = st.selectbox("Select Unit Price Column", df.columns.tolist(), index=df.columns.tolist().index("UnitPrice") if "UnitPrice" in df.columns else 0)
60
-
61
- # Check if the selected columns exist
62
- if customer_col not in df.columns or quantity_col not in df.columns or unit_price_col not in df.columns:
63
- st.error("One or more selected columns do not exist in the dataset. Please select valid columns.")
64
- st.stop()
65
-
66
- # Preprocess data
67
- df = df.dropna(subset=[customer_col]) # Remove rows without CustomerID
68
- df["TotalSpent"] = pd.to_numeric(df[quantity_col], errors='coerce') * pd.to_numeric(df[unit_price_col], errors='coerce') # Ensure numeric type
69
 
70
- # Ensure that TotalSpent column is not NaN
 
 
 
 
 
 
 
 
 
 
71
  df = df.dropna(subset=["TotalSpent"])
72
 
73
- # Aggregate data by Customer
74
  customer_data = df.groupby(customer_col).agg({
75
  "TotalSpent": "sum",
76
  quantity_col: "sum",
77
  unit_price_col: "mean"
78
  }).rename(columns={quantity_col: "NumTransactions", unit_price_col: "AvgUnitPrice"})
79
 
80
- # Debug: Check if 'NumTransactions' exists in the DataFrame
81
- st.write("### Available columns in the aggregated customer data:")
82
- st.write(customer_data.columns.tolist())
83
-
84
- # Standardize the data
85
  scaler = StandardScaler()
86
  customer_scaled = pd.DataFrame(scaler.fit_transform(customer_data), columns=customer_data.columns, index=customer_data.index)
87
 
88
- # Customer Segmentation Tab
89
- with tab3:
90
- if uploaded_file is not None:
91
- # User selects the number of clusters
92
  num_clusters = st.slider("Select Number of Clusters", min_value=2, max_value=10, value=3)
93
-
94
- # Apply K-Means clustering
95
  model = KMeans(n_clusters=num_clusters, random_state=42)
96
  customer_data["Cluster"] = model.fit_predict(customer_scaled)
97
 
98
- # Visualize the clusters
99
  st.write("### Clusters Visualization")
100
  fig, ax = plt.subplots()
101
  scatter = ax.scatter(customer_data["TotalSpent"], customer_data["NumTransactions"], c=customer_data["Cluster"], cmap='viridis')
@@ -105,52 +91,7 @@ with tab3:
105
  plt.colorbar(scatter, label="Cluster")
106
  st.pyplot(fig)
107
 
108
- # Show the segmented customer data
109
- st.write("### Customer Segments Data")
110
- st.write(customer_data.head())
111
-
112
- # Option to download the segmented data
113
  csv = customer_data.to_csv(index=True)
114
- st.download_button(
115
- label="Download Segmented Customer Data",
116
- data=csv,
117
- file_name="segmented_customer_data.csv",
118
- mime="text/csv"
119
- )
120
  else:
121
- st.write("Please upload a dataset to start.")
122
-
123
- # Download Dataset Tab
124
- with tab4:
125
- st.write("""
126
- You can download the sample 'Online Retail' dataset to get started with customer segmentation tasks.
127
- Click the button below to download the dataset in CSV format.
128
- """)
129
-
130
- # Direct Google Drive link to the 'Online Retail' dataset (for direct download)
131
- dataset_url_online_retail = "https://drive.google.com/uc?id=1djBqO2sdHfy9DGZQXZu2Er8LUUXtp9Kr&export=download"
132
-
133
- # Direct Google Drive link to the new dataset (for direct download)
134
- dataset_url_new_file = "https://drive.google.com/uc?id=1PbGJSdcyDInsu-9Ua4iHzQh-YpVk_RqT&export=download"
135
-
136
- # Download the file from the URLs
137
- response_online_retail = requests.get(dataset_url_online_retail)
138
- file_data_online_retail = response_online_retail.text # Get the content as text
139
-
140
- response_new_file = requests.get(dataset_url_new_file)
141
- file_data_new_file = response_new_file.text # Get the content as text
142
-
143
- # Convert the CSV data into a CSV download for Streamlit
144
- st.download_button(
145
- label="Download Online Retail Dataset",
146
- data=file_data_online_retail,
147
- file_name="Online_Retail.csv",
148
- mime="text/csv"
149
- )
150
-
151
- st.download_button(
152
- label="Download New Dataset",
153
- data=file_data_new_file,
154
- file_name="New_Dataset.csv",
155
- mime="text/csv"
156
- )
 
5
  from sklearn.preprocessing import StandardScaler
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
 
 
8
 
9
  # App title
10
  st.title("🛍️ Customer Segmentation Tool")
11
 
12
+ # Streamlit Tabs
13
  tab1, tab2, tab3, tab4 = st.tabs(["📖 About", "📊 Dataset Overview", "🧑‍🤝‍🧑 Customer Segmentation", "📥 Download Dataset"])
14
 
15
  # About Tab
16
  with tab1:
17
  st.write("""
18
+ This app segments customers based on their purchasing behavior using unsupervised learning.
19
+ You can upload one or two datasets for analysis.
 
 
 
 
 
20
  """)
21
 
22
  # File uploader in the Dataset Tab
23
  with tab2:
24
+ uploaded_file1 = st.file_uploader("Upload First Dataset", type=["csv", "xlsx"], key="file1")
25
+ uploaded_file2 = st.file_uploader("Upload Second Dataset (Optional)", type=["csv", "xlsx"], key="file2")
26
+
27
+ def load_data(uploaded_file):
28
+ if uploaded_file is not None:
29
+ try:
30
+ if uploaded_file.name.endswith('.csv'):
31
+ df = pd.read_csv(uploaded_file, encoding='ISO-8859-1', on_bad_lines='skip')
32
+ elif uploaded_file.name.endswith('.xlsx'):
33
+ df = pd.read_excel(uploaded_file)
34
+ return df
35
+ except Exception as e:
36
+ st.error(f"Error loading dataset: {e}")
37
+ return None
38
+
39
+ df1 = load_data(uploaded_file1)
40
+ df2 = load_data(uploaded_file2)
41
+
42
+ if df1 is not None:
43
+ st.write("### First Dataset Overview")
44
+ st.write(df1.head())
45
+
46
+ if df2 is not None:
47
+ st.write("### Second Dataset Overview")
48
+ st.write(df2.head())
49
+
50
+ if df1 is not None and df2 is not None:
51
+ merge_option = st.radio("How would you like to combine the datasets?", ("Concatenate", "Keep Separate"))
52
+ if merge_option == "Concatenate":
53
+ df = pd.concat([df1, df2], ignore_index=True)
54
+ else:
55
+ df = None # Handle separately in clustering
56
+ else:
57
+ df = df1 if df1 is not None else df2
 
 
 
 
58
 
59
+ # Customer Segmentation Tab
60
+ with tab3:
61
+ if df is not None:
62
+ # Column selection
63
+ st.write("### Select Columns")
64
+ customer_col = st.selectbox("Select Customer Column", df.columns.tolist(), index=0)
65
+ quantity_col = st.selectbox("Select Quantity Column", df.columns.tolist(), index=0)
66
+ unit_price_col = st.selectbox("Select Unit Price Column", df.columns.tolist(), index=0)
67
+
68
+ df = df.dropna(subset=[customer_col])
69
+ df["TotalSpent"] = pd.to_numeric(df[quantity_col], errors='coerce') * pd.to_numeric(df[unit_price_col], errors='coerce')
70
  df = df.dropna(subset=["TotalSpent"])
71
 
 
72
  customer_data = df.groupby(customer_col).agg({
73
  "TotalSpent": "sum",
74
  quantity_col: "sum",
75
  unit_price_col: "mean"
76
  }).rename(columns={quantity_col: "NumTransactions", unit_price_col: "AvgUnitPrice"})
77
 
 
 
 
 
 
78
  scaler = StandardScaler()
79
  customer_scaled = pd.DataFrame(scaler.fit_transform(customer_data), columns=customer_data.columns, index=customer_data.index)
80
 
 
 
 
 
81
  num_clusters = st.slider("Select Number of Clusters", min_value=2, max_value=10, value=3)
 
 
82
  model = KMeans(n_clusters=num_clusters, random_state=42)
83
  customer_data["Cluster"] = model.fit_predict(customer_scaled)
84
 
 
85
  st.write("### Clusters Visualization")
86
  fig, ax = plt.subplots()
87
  scatter = ax.scatter(customer_data["TotalSpent"], customer_data["NumTransactions"], c=customer_data["Cluster"], cmap='viridis')
 
91
  plt.colorbar(scatter, label="Cluster")
92
  st.pyplot(fig)
93
 
 
 
 
 
 
94
  csv = customer_data.to_csv(index=True)
95
+ st.download_button("Download Segmented Customer Data", data=csv, file_name="segmented_customer_data.csv", mime="text/csv")
 
 
 
 
 
96
  else:
97
+ st.write("Please upload at least one dataset to start.")