import datetime

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from sklearn.cluster import KMeans
|
|
@st.cache_data
def load_dataset(path: str) -> pd.DataFrame:
    """
    Load the CSV dataset and split its ``Date`` column into calendar parts.

    The ``Date`` column is parsed with ``pd.to_datetime`` and replaced by
    ``Year``, ``Month`` and ``Day`` columns; all other columns are returned
    as read from the file.

    The result is cached with ``st.cache_data`` (the Streamlit cache meant
    for data such as DataFrames), so callers get their own copy and cannot
    corrupt the cached frame by mutating it.

    Args:
        path (str): Path to the CSV file. It must contain a ``Date`` column.

    Returns:
        pd.DataFrame: The dataset with ``Year``, ``Month`` and ``Day``
        appended and ``Date`` removed.

    Raises:
        KeyError: If the file has no ``Date`` column.
    """
    df = pd.read_csv(path)

    dates = pd.to_datetime(df['Date'])

    # Drop first, then append, so the new columns end up last.
    df = df.drop(columns='Date')
    df['Year'] = dates.dt.year
    df['Month'] = dates.dt.month
    df['Day'] = dates.dt.day

    return df
|
|
@st.cache_resource
def load_model(model_path: str) -> KMeans:
    """
    Load a persisted model from disk with joblib.

    The loaded object is cached with ``st.cache_resource``, so the file is
    deserialised once and the same object is reused afterwards.

    Args:
        model_path (str): Path to the joblib file holding the model.

    Returns:
        KMeans: The object stored in the file (expected to be a fitted
        KMeans model; the type is not checked here).
    """
    # SECURITY: joblib.load unpickles the file, and unpickling can execute
    # arbitrary code. Only pass paths to model files from a trusted source.
    return joblib.load(model_path)
|
|
def process_data(input_X: pd.DataFrame) -> np.ndarray:
    """
    Convert a dataframe into the feature matrix the KMeans model expects.

    The input is copied, so the caller's frame is never modified. If a
    ``Date`` column is present it is parsed and replaced by ``Year``,
    ``Month`` and ``Day``; frames that already carry those three columns
    (and no ``Date``) are used as they are. ``Source`` is one-hot encoded,
    and the two indicator columns the model uses are added as zeros when the
    data does not produce them.

    Args:
        input_X (pd.DataFrame): One or more rows with ``Source`` and
            ``Mean`` columns, plus either ``Date`` or ``Year``/``Month``/
            ``Day``. Any other columns are ignored.

    Returns:
        np.ndarray: Array of shape ``(n_rows, 6)`` with columns in the order
        ``Mean, Year, Month, Day, Source_GCAG, Source_GISTEMP``.

    Raises:
        KeyError: If ``Source``, ``Mean`` or the date information is missing.
    """
    feature_order = ['Mean', 'Year', 'Month', 'Day', 'Source_GCAG', 'Source_GISTEMP']

    frame = input_X.copy()

    if 'Date' in frame.columns:
        dates = pd.to_datetime(frame['Date'])
        frame['Year'] = dates.dt.year
        frame['Month'] = dates.dt.month
        frame['Day'] = dates.dt.day
        frame = frame.drop(columns='Date')

    # Encode only `Source`. Without `columns=`, get_dummies encodes every
    # string column and raises ValueError when their count differs from the
    # length of `prefix`.
    frame = pd.get_dummies(frame, columns=['Source'], prefix=['Source'], dtype=int)

    # A single-row input only yields the indicator for its own source, so
    # the other expected indicator has to be filled in.
    for col in ('Source_GCAG', 'Source_GISTEMP'):
        if col not in frame.columns:
            frame[col] = 0

    # Fix the column order and drop anything that is not a feature.
    return frame[feature_order].to_numpy()
| |
def plot_clusters(model: KMeans, X: np.ndarray, input_X: np.ndarray) -> None:
    """
    Plot the dataset by cluster and mark where the user's input falls.

    Only the first two columns of ``X``, of the centroids and of
    ``input_X`` are drawn. The predicted cluster of the first input row is
    also written to the page with ``st.write``.

    Args:
        model (KMeans): A fitted KMeans model.
        X (np.ndarray): The dataset as a 2-D array. Rows are coloured using
            ``model.labels_``, so this is assumed to be the data the model
            was fitted on, in the same row order (not checked here).
        input_X (np.ndarray): 2-D array with the input row(s) to predict.

    Returns:
        None
    """
    centroids = model.cluster_centers_
    labels = model.labels_

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Take the cluster count from the model rather than assuming three.
        for cluster in range(len(centroids)):
            cluster_points = X[labels == cluster]
            # Draw on `ax` explicitly; pyplot's implicit "current axes" is
            # global state shared by everything running in the process.
            ax.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f'Cluster {cluster}')

        ax.scatter(centroids[:, 0], centroids[:, 1], s=200, c='black', marker='X', label='Centroids')

        predictions = model.predict(input_X)

        st.write(f"Predicted Cluster: {predictions[0]}")

        ax.scatter(input_X[:, 0], input_X[:, 1], s=300, c='red', marker='P', label=f'Predicted Cluster: {predictions[0]}')

        ax.set_title('K-Means Clustering with Predicted Point')
        ax.legend()
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates until it is closed; without
        # this each call would leave another figure in memory.
        plt.close(fig)
|
|
def main():
    """
    Render the Streamlit page.

    The first tab collects a source, a mean temperature and a date. When the
    button is pressed it loads the KMeans model, predicts the cluster of that
    input and plots it against the whole dataset. The second tab shows a
    description of the dataset and its citations.
    """
    st.title("Global Temperature Time Series")
    df = load_dataset("./monthly_csv.csv")

    tab1, tab2 = st.tabs(["KMeans Prediction", "About the Dataset"])

    with tab1:
        st.header("Input Data")

        source = st.selectbox("Choose your source platform.", tuple(df['Source'].unique()))
        # Plain floats keep the slider bounds (and its return value)
        # independent of NumPy scalar types.
        mean_temp = st.slider("Choose your avg. temp", float(df['Mean'].min()), float(df['Mean'].max()))
        # NOTE(review): the label mentions "air quality" although the app is
        # about temperature; left unchanged, confirm the intended wording.
        # A date object is used for min_value because not every Streamlit
        # release accepts an ISO string there.
        date = st.date_input(
            "Choose date to monitor air quality",
            min_value=datetime.date(1980, 1, 1),
            max_value=None,
        )

        if st.button("Predict Input!"):
            d = pd.DataFrame({"Source": [source],
                              "Mean": [mean_temp],
                              "Date": [date]
                              })
            input_X = process_data(d)

            model = load_model('./models/kmeans_model.pkl')
            if model is not None:
                processed_df = process_data(df)
                plot_clusters(model, processed_df, input_X)

    with tab2:
        st.caption("Global Temperature Time Series. Data are included from the GISS Surface Temperature (GISTEMP) analysis and the global component of Climate at a Glance (GCAG). Two datasets are provided: 1) global monthly mean and 2) annual mean temperature anomalies in degrees Celsius from 1880 to the present.")

        st.write("Citation:")
        st.caption("GISTEMP: NASA Goddard Institute for Space Studies (GISS) Surface Temperature Analysis, Global Land-Ocean Temperature Index.")
        st.caption("NOAA National Climatic Data Center (NCDC), global component of Climate at a Glance (GCAG).)")
|
|
# Build the page only when this file is run as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()