Spaces:
Sleeping
Sleeping
File size: 7,007 Bytes
c96612d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | import streamlit as st
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder, StandardScaler
# Page title, rendered as raw HTML so size, color, and centering can be set inline.
st.markdown(
    "<h1 style='text-align: center; font-size: 48px; color: #6699CC;'>"
    "Next Day Rain Prediction</h1>",
    unsafe_allow_html=True,
)
# Function to create cyclical features
def create_date_features(df, date_column='Date'):
    """Return a copy of *df* with a parsed date column plus year and cyclical date features.

    The date column is converted with ``pd.to_datetime``. Then these columns are added:
    ``year``, ``month_sin``/``month_cos`` (period 12) and ``day_sin``/``day_cos``
    (period 31). The sine/cosine pairs place each value on a circle, so the ends of a
    cycle (e.g. December and January) stay numerically close. The input frame is
    not modified.

    Args:
        df: Source frame containing ``date_column``.
        date_column: Name of the column holding dates (default ``'Date'``).

    Returns:
        A new DataFrame with the converted date column and the five derived columns.
    """
    out = df.copy()
    out[date_column] = pd.to_datetime(out[date_column])
    stamps = out[date_column]
    out['year'] = stamps.dt.year

    two_pi = 2 * np.pi
    cyclical_specs = (
        ('month_sin', 'month_cos', stamps.dt.month, 12),
        ('day_sin', 'day_cos', stamps.dt.day, 31),
    )
    for sin_name, cos_name, values, period in cyclical_specs:
        angle = two_pi * values / period
        out[sin_name] = np.sin(angle)
        out[cos_name] = np.cos(angle)
    return out
# Load the dataset
@st.cache_data
def load_dataset():
    """Read weatherAUS.csv from the working directory and append the date-derived features.

    Wrapped in ``st.cache_data``, so reruns reuse the parsed frame instead of re-reading the CSV.

    Returns:
        The dataset as returned by ``create_date_features``.
    """
    raw = pd.read_csv('weatherAUS.csv')
    enriched = create_date_features(raw)
    return enriched
# Cache function to convert DataFrame to CSV
@st.cache_data
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes (index omitted) for ``st.download_button``."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
# Define the neural network model
class Enhanced_ANN_Model(nn.Module):
    """Feed-forward network: three Linear -> BatchNorm1d -> ReLU stages and a linear head.

    Layer widths are ``input_dim -> 128 -> 64 -> 32 -> 1``. The attribute names
    (``fc1``..``fc4``, ``bn1``..``bn3``) become the keys of the saved state_dict.
    Renaming them would break loading of existing weight files.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Registration order (fc1, bn1, fc2, ...) is kept so parameter init and
        # state_dict ordering match previously trained checkpoints.
        self.fc1, self.bn1 = nn.Linear(input_dim, 128), nn.BatchNorm1d(128)
        self.fc2, self.bn2 = nn.Linear(128, 64), nn.BatchNorm1d(64)
        self.fc3, self.bn3 = nn.Linear(64, 32), nn.BatchNorm1d(32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to raw logits of shape (batch, 1); no sigmoid here."""
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in hidden_stages:
            x = torch.relu(norm(linear(x)))
        return self.fc4(x)
# Load pre-trained model
@st.cache_resource
def load_model(input_dim=26, weights_path="model_weights.pth"):
    """Build ``Enhanced_ANN_Model``, load saved weights onto the CPU, and return it in eval mode.

    Args:
        input_dim: Number of input features. Defaults to 26, matching the 26 feature
            columns the app assembles for prediction and the trained weights.
        weights_path: File written by ``torch.save``. It may hold either a state_dict
            or a whole pickled model.

    Returns:
        The model in eval mode, or ``None`` if loading failed. In that case an error
        message is rendered in the app.

    NOTE(review): ``st.cache_resource`` also caches a ``None`` return, so a failed load
    persists until the resource cache is cleared.
    """
    import html  # local import: only needed to escape the error message below

    model = Enhanced_ANN_Model(input_dim)
    try:
        # NOTE(review): torch.load unpickles the file; only load weight files you trust.
        loaded = torch.load(weights_path, map_location=torch.device('cpu'))
        if isinstance(loaded, dict):
            model.load_state_dict(loaded)
        else:
            # The file holds a complete pickled module rather than a state_dict.
            model = loaded
        model.eval()
        return model
    except Exception as e:
        # Escape the message before rendering it as HTML. torch/pickle errors often contain
        # "<class '...'>"-style fragments that would otherwise vanish or be injected as markup.
        st.markdown(
            f"<p style='color: #0000FF;'>Error loading model: {html.escape(str(e))}</p>",
            unsafe_allow_html=True,
        )
        return None
# Load dataset, build the model's feature columns, and run a prediction on a chosen row.
import html  # escapes exception text before it is rendered with unsafe_allow_html

try:
    df = load_dataset()
    # Display dataset preview
    st.markdown("<h3 style='color: #6699CC;'>Dataset Preview:</h3>", unsafe_allow_html=True)
    st.dataframe(df.head())
    # Base required columns
    base_columns = ['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
                    'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
                    'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
                    'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm',
                    'Temp9am', 'Temp3pm', 'RainToday']
    # Add date-derived features (created by create_date_features inside load_dataset).
    # The order of this list is the feature order of the model's input tensor.
    required_columns = base_columns + ['month_sin', 'month_cos', 'day_sin', 'day_cos', 'year']
    # Sorted so the reported list is stable across reruns (set iteration order is arbitrary).
    missing = sorted(set(required_columns) - set(df.columns))
    if missing:
        missing_columns = ', '.join(missing)
        st.markdown(f"<p style='color: #6699CC;'>Missing required columns: {missing_columns}</p>", unsafe_allow_html=True)
    else:
        # Label Encoding for categorical columns; NaN becomes an explicit 'missing' category.
        # NOTE(review): the encoders and scaler below are fit on the full CSV every time the
        # script runs. Confirm this matches the preprocessing used when the weights were trained.
        label_encoders = {}
        categorical_cols = ['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm', 'RainToday']
        for col in categorical_cols:
            le = LabelEncoder()
            df[col] = df[col].fillna('missing')
            df[col] = le.fit_transform(df[col].astype(str))
            label_encoders[col] = le
        # Standard Scaling for numerical features (mean-imputed first so the scaler sees no NaN)
        scaler = StandardScaler()
        numerical_cols = [col for col in required_columns if col not in categorical_cols]
        df[numerical_cols] = df[numerical_cols].fillna(df[numerical_cols].mean())
        df[numerical_cols] = scaler.fit_transform(df[numerical_cols])
        # Select a row for prediction
        st.markdown("<h3 style='color: #6699CC;'>Select a Row for Prediction:</h3>", unsafe_allow_html=True)
        st.markdown("""
<style>
.stSelectbox label {
color: #ff6347; /* Set your desired color here */
}
</style>
""", unsafe_allow_html=True)
        # Selectbox widget
        selected_row_index = st.selectbox("Select a Row Index", options=range(len(df)), index=0)
        predict_button = st.button("Predict Weather")
        if predict_button:
            model = load_model()
            # load_model renders its own error message and returns None on failure.
            if model is not None:
                # Get all required columns for prediction, in the model's feature order
                row_to_use = df.iloc[selected_row_index][required_columns]
                # Ensure all values are float32; unsqueeze adds the batch dimension -> (1, n_features)
                row_tensor = torch.tensor(row_to_use.values.astype(np.float32)).unsqueeze(0)
                with torch.no_grad():
                    # The network ends in a plain Linear layer (raw logit); sigmoid gives a probability.
                    prediction = torch.sigmoid(model(row_tensor)).item()
                # Display results
                st.markdown("<h3 style='color: #32a852;'>Row selected for prediction:</h3>", unsafe_allow_html=True)
                st.write(row_to_use)
                result = "Rain Expected" if prediction >= 0.5 else "No Rain Expected"
                probability = prediction * 100
                st.markdown(f"<h3 style='color: #32a852;'>Rain Prediction Result: {result}</h3>", unsafe_allow_html=True)
                st.markdown(f"<h3 style='color: #32a852;'>Probability of Rain: {probability:.2f}%</h3>", unsafe_allow_html=True)
                # Show original date for reference
                original_date = df.iloc[selected_row_index]['Date']
                st.markdown(f"<h3 style='color: #32a852;'>Date: {original_date}</h3>", unsafe_allow_html=True)
                # Provide download option
                result_df = row_to_use.to_frame().T
                result_df['Rain Prediction'] = result
                result_df['Rain Probability'] = f"{probability:.2f}%"
                result_df['Date'] = original_date
                result_csv = convert_df(result_df)
                st.download_button(
                    label="Download Prediction Result",
                    data=result_csv,
                    file_name="Rain_Prediction_Result.csv",
                    mime="text/csv",
                )
except Exception as e:
    # Top-level boundary: report any failure in the app instead of crashing the script.
    # The message is escaped because it is rendered as HTML.
    st.markdown(f"<p style='color: #32a852;'>An error occurred: {html.escape(str(e))}</p>", unsafe_allow_html=True)
|