Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from sklearn.preprocessing import LabelEncoder, StandardScaler | |
# Page heading, centered and rendered from raw HTML.
_PAGE_TITLE_HTML = "<h1 style='text-align: center; font-size: 48px; color: #6699CC;'>Next Day Rain Prediction</h1>"
st.markdown(_PAGE_TITLE_HTML, unsafe_allow_html=True)
# Derive year and cyclical (sin/cos) month/day features from a date column
def create_date_features(df, date_column='Date'):
    """Return a copy of *df* with date-derived feature columns appended.

    The input frame is not modified. In the copy, *date_column* is parsed
    with ``pd.to_datetime`` and these columns are added, in this order:
    ``year``, ``month_sin``, ``month_cos``, ``day_sin``, ``day_cos``.

    Month uses a period of 12 and day-of-month a fixed period of 31 (for
    every month), so each maps onto a point on the unit circle.
    """
    out = df.copy()
    dates = pd.to_datetime(out[date_column])
    out[date_column] = dates
    out['year'] = dates.dt.year
    two_pi = 2 * np.pi
    # (sin column, cos column, integer component, period)
    cyclical_specs = (
        ('month_sin', 'month_cos', dates.dt.month, 12),
        ('day_sin', 'day_cos', dates.dt.day, 31),
    )
    for sin_col, cos_col, component, period in cyclical_specs:
        angle = two_pi * component / period
        out[sin_col] = np.sin(angle)
        out[cos_col] = np.cos(angle)
    return out
# Load the dataset (cached across Streamlit reruns)
@st.cache_data
def load_dataset():
    """Read ``weatherAUS.csv`` and append the date-derived features.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the CSV would be parsed and featurized on each rerun.
    ``st.cache_data`` stores the result once and hands every caller its own
    copy, so callers may modify the returned DataFrame in place without
    corrupting the cached value.

    Returns:
        The DataFrame produced by ``create_date_features`` on the raw CSV.
    """
    df = pd.read_csv('weatherAUS.csv')
    return create_date_features(df)
# Serialize a DataFrame to CSV bytes for st.download_button
def convert_df(df):
    """Return *df* as UTF-8 encoded CSV bytes, without the index column.

    No caching is applied: the conversion runs on every call.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
# Define the neural network model
class Enhanced_ANN_Model(nn.Module):
    """Feed-forward network with one output unit.

    Three hidden stages of Linear -> BatchNorm1d -> ReLU (128, 64 and 32
    units) are followed by a final Linear layer to a single output. No
    sigmoid is applied here; ``forward`` returns the raw score.

    The submodule attribute names (fc1..fc4, bn1..bn3) determine the
    ``state_dict`` keys, so they must not change or saved weights will no
    longer load.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64, 32)
        self.bn3 = nn.BatchNorm1d(32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, x):
        """Map a 2-D ``(batch, input_dim)`` tensor to ``(batch, 1)`` scores."""
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in hidden_stages:
            x = torch.relu(norm(linear(x)))
        return self.fc4(x)
# Load pre-trained model (cached across Streamlit reruns)
@st.cache_resource
def _load_model_cached(weights_path, input_dim):
    """Build the network from *weights_path* and switch it to eval mode.

    Cached with ``st.cache_resource`` so the checkpoint is read from disk
    once rather than on every "Predict" click. Errors propagate to the caller
    instead of being turned into ``None`` here, so a failed load is not
    stored as the cached result and a later rerun can retry.

    Accepts either a state dict (loaded into ``Enhanced_ANN_Model``) or a
    pickled ``nn.Module``; anything else raises ``TypeError``.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only ship trusted checkpoints. Newer torch versions default to
    # weights_only=True, which rejects the full-module branch below.
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    if isinstance(checkpoint, dict):
        model = Enhanced_ANN_Model(input_dim)
        model.load_state_dict(checkpoint)
    elif isinstance(checkpoint, nn.Module):
        model = checkpoint
    else:
        raise TypeError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")
    model.eval()
    return model


def load_model():
    """Return the pre-trained model in eval mode, or ``None`` on failure.

    Any load error is reported in the app with ``st.error`` rather than
    raised.
    """
    input_dim = 26  # 21 base columns + 5 date features, matching the trained model
    try:
        return _load_model_cached("model_weights.pth", input_dim)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None
def _encode_and_scale(df, required_columns, categorical_cols):
    """Return a model-ready copy of *df*; the input DataFrame is not modified.

    In each of *categorical_cols*, missing values become the string
    'missing' and the column is label-encoded. Every other column in
    *required_columns* is mean-imputed and then standardized.

    NOTE(review): the encoders and scaler are fit here, at run time, on the
    loaded dataset rather than restored from training. Predictions are only
    meaningful if training used the same data, preprocessing and column
    order. TODO confirm against the training pipeline.
    """
    encoded = df.copy()
    for col in categorical_cols:
        encoded[col] = LabelEncoder().fit_transform(encoded[col].fillna('missing').astype(str))
    numerical_cols = [col for col in required_columns if col not in categorical_cols]
    encoded[numerical_cols] = encoded[numerical_cols].fillna(encoded[numerical_cols].mean())
    encoded[numerical_cols] = StandardScaler().fit_transform(encoded[numerical_cols])
    return encoded


def _rain_probability(model, features):
    """Return the model's rain probability in [0, 1] for one feature row.

    *features* is a 1-D sequence of numeric model inputs in the model's
    column order. The network's single output is treated as a logit and
    passed through a sigmoid.
    """
    row_tensor = torch.tensor(np.asarray(features, dtype=np.float32)).unsqueeze(0)
    with torch.no_grad():
        logit = model(row_tensor)
    return torch.sigmoid(logit).item()


# Load dataset and drive the UI
try:
    df = load_dataset()
    # Display dataset preview
    st.markdown("<h3 style='color: #6699CC;'>Dataset Preview:</h3>", unsafe_allow_html=True)
    st.dataframe(df.head())
    # Model inputs, in the order the network consumes them:
    # 21 base columns + 5 date-derived features = 26 (input_dim in load_model).
    base_columns = ['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
                    'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
                    'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
                    'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm',
                    'Temp9am', 'Temp3pm', 'RainToday']
    required_columns = base_columns + ['month_sin', 'month_cos', 'day_sin', 'day_cos', 'year']
    categorical_cols = ['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm', 'RainToday']

    missing_columns = set(required_columns) - set(df.columns)
    if missing_columns:
        st.error(f"Missing required columns: {', '.join(sorted(missing_columns))}")
    elif df.empty:
        st.error("The dataset has no rows to predict on.")
    else:
        model_inputs = _encode_and_scale(df, required_columns, categorical_cols)

        # Select a row for prediction
        st.markdown("<h3 style='color: #6699CC;'>Select a Row for Prediction:</h3>", unsafe_allow_html=True)
        # Single-line CSS so Markdown cannot mistake indented lines for a code block.
        st.markdown(
            "<style>.stSelectbox label, .stNumberInput label { color: #ff6347; }</style>",
            unsafe_allow_html=True,
        )
        # A bounded number input rather than a selectbox: a selectbox would
        # send one option per dataset row to the browser on every rerun.
        selected_row_index = int(st.number_input(
            "Select a Row Index", min_value=0, max_value=len(df) - 1, value=0, step=1,
        ))

        if st.button("Predict Weather"):
            model = load_model()
            if model is not None:
                rain_fraction = _rain_probability(
                    model, model_inputs.iloc[selected_row_index][required_columns]
                )
                result = "Rain Expected" if rain_fraction >= 0.5 else "No Rain Expected"
                probability = rain_fraction * 100

                # Show and export the row as it appears in the dataset, not the
                # encoded/standardized values that were fed to the model.
                row_to_use = df.iloc[selected_row_index][required_columns]
                original_date = df.iloc[selected_row_index]['Date']

                st.markdown("<h3 style='color: #32a852;'>Row selected for prediction:</h3>", unsafe_allow_html=True)
                st.write(row_to_use)
                st.markdown(f"<h3 style='color: #32a852;'>Rain Prediction Result: {result}</h3>", unsafe_allow_html=True)
                st.markdown(f"<h3 style='color: #32a852;'>Probability of Rain: {probability:.2f}%</h3>", unsafe_allow_html=True)
                st.markdown(f"<h3 style='color: #32a852;'>Date: {original_date}</h3>", unsafe_allow_html=True)

                # Provide download option
                result_df = row_to_use.to_frame().T
                result_df['Rain Prediction'] = result
                result_df['Rain Probability'] = f"{probability:.2f}%"
                result_df['Date'] = original_date
                st.download_button(
                    label="Download Prediction Result",
                    data=convert_df(result_df),
                    file_name="Rain_Prediction_Result.csv",
                    mime="text/csv",
                )
except Exception as e:
    # Top-level boundary for the Streamlit page: surface the failure in the UI.
    st.error(f"An error occurred: {e}")