arichar14 committed on
Commit
372d0e2
·
verified ·
1 Parent(s): ef92502

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -1
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import pandas as pd
2
  from sklearn.cluster import KMeans
3
  import matplotlib.pyplot as plt
4
  import gradio as gr
@@ -48,4 +48,63 @@ demo = gr.Interface(
48
  description="Upload a CSV of cities with AvgMonthlyTourists, AvgTemp, and Hotels. Choose number of clusters to group similar cities."
49
  )
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  demo.launch()
 
1
+ '''import pandas as pd
2
  from sklearn.cluster import KMeans
3
  import matplotlib.pyplot as plt
4
  import gradio as gr
 
48
  description="Upload a CSV of cities with AvgMonthlyTourists, AvgTemp, and Hotels. Choose number of clusters to group similar cities."
49
  )
50
 
51
+ demo.launch()'''
52
+
53
+ import pandas as pd
54
+ from sklearn.cluster import KMeans
55
+ from sklearn.preprocessing import StandardScaler
56
+ import matplotlib.pyplot as plt
57
+ import gradio as gr
58
+ import tempfile
59
+
60
def cluster_tourism(file, n_clusters):
    """Cluster the cities in a CSV with KMeans and plot the result.

    The three feature columns are standardized before clustering so that no
    single column dominates the distance computation because of its scale.

    Args:
        file: Path (or any object ``pandas.read_csv`` accepts) of a CSV that
            contains the columns ``AvgMonthlyTourists``, ``AvgTemp`` and
            ``Hotels``.
        n_clusters: Number of clusters to form. Whole-number floats such as
            ``3.0`` are accepted and converted to ``int``.

    Returns:
        A ``(csv_path, plot_path)`` tuple of temporary file paths: the input
        data with an added ``Cluster`` column, and a PNG scatter plot of
        ``AvgMonthlyTourists`` against ``AvgTemp`` with one series per
        cluster. The files are created with ``delete=False`` and are not
        removed by this function.

    Raises:
        ValueError: If no file is given, a required column is missing,
            ``n_clusters`` is smaller than 1, or the CSV has fewer rows than
            ``n_clusters``.
    """
    feature_columns = ['AvgMonthlyTourists', 'AvgTemp', 'Hotels']

    if file is None:
        raise ValueError("No CSV file was provided.")

    # Load CSV
    df = pd.read_csv(file)

    # Fail early with a readable message instead of an opaque KeyError.
    missing = [column for column in feature_columns if column not in df.columns]
    if missing:
        raise ValueError(
            f"CSV is missing required column(s): {', '.join(missing)}"
        )

    # The value may arrive as a float (e.g. from a slider widget); both
    # range() and KMeans need a real int.
    n_clusters = int(n_clusters)
    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1.")
    if len(df) < n_clusters:
        raise ValueError(
            f"Cannot form {n_clusters} clusters from {len(df)} row(s)."
        )

    # Standardize features
    features_scaled = StandardScaler().fit_transform(df[feature_columns])

    # KMeans clustering (fixed seed so repeated runs give the same labels)
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    df['Cluster'] = kmeans.fit_predict(features_scaled)

    # Reserve the output paths and close the handles right away: the files
    # are written by name below, and an open handle would leak a descriptor
    # (and blocks re-opening the file on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp_csv:
        csv_path = tmp_csv.name
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_plot:
        plot_path = tmp_plot.name

    # Save clustered CSV to temporary file
    df.to_csv(csv_path, index=False)

    # Plot clusters on an explicit figure rather than pyplot's implicit
    # "current figure", and always release it, even if plotting fails.
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for cluster in range(n_clusters):
            subset = df[df['Cluster'] == cluster]
            ax.scatter(
                subset['AvgMonthlyTourists'],
                subset['AvgTemp'],
                label=f'Cluster {cluster}',
            )
        ax.set_xlabel('Avg Monthly Tourists')
        ax.set_ylabel('Avg Temp')
        ax.set_title('City Clusters')
        ax.legend()

        # Save plot to temporary file
        fig.savefig(plot_path)
    finally:
        plt.close(fig)

    return csv_path, plot_path
95
+
96
# Widgets shown on the input side: the CSV upload and the cluster-count slider.
input_components = [
    gr.File(file_types=[".csv"], type="filepath", label="Upload CSV"),
    gr.Slider(minimum=2, maximum=10, step=1, label="Number of Clusters"),
]

# Widgets shown on the output side: the labelled CSV and the scatter plot.
output_components = [
    gr.File(label="CSV with Cluster Labels"),
    gr.Image(label="Cluster Plot"),
]

# Wire cluster_tourism to the widgets above.
demo = gr.Interface(
    title="City Clustering",
    description="Upload a CSV of cities with AvgMonthlyTourists, AvgTemp, and Hotels. Choose number of clusters to group similar cities.",
    fn=cluster_tourism,
    inputs=input_components,
    outputs=output_components,
)

# Start the Gradio app.
demo.launch()