# RabiesIQ / src/streamlit_app.py
# Iman Kozly
# Update src/streamlit_app.py
# 3bd3ce3 verified
import os
import streamlit as st
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# --- Load model and encoders ---
# Artifacts live in the same directory as this script. Resolving them from
# the script location (instead of "./src/...") makes loading independent of
# the working directory the app is launched from.
_ARTIFACT_DIR = os.path.dirname(os.path.abspath(__file__))


# Streamlit re-executes this script on every widget interaction, so the loads
# are cached. The spinner is disabled so that no UI element is emitted before
# st.set_page_config() runs further down.
@st.cache_resource(show_spinner=False)
def _load_artifact(filename: str):
    """Load a joblib-serialized object stored next to this script.

    Args:
        filename: File name (not a path) of the artifact.

    Returns:
        The deserialized object; the same instance is reused across reruns.
    """
    return joblib.load(os.path.join(_ARTIFACT_DIR, filename))


@st.cache_data(show_spinner=False)
def _load_dataset(filename: str) -> pd.DataFrame:
    """Read an Excel workbook stored next to this script into a DataFrame.

    Args:
        filename: File name (not a path) of the workbook.

    Returns:
        The workbook's first sheet as a DataFrame.
    """
    return pd.read_excel(os.path.join(_ARTIFACT_DIR, filename))


model = _load_artifact("DB_model_IMAN.pkl")
cat_encoders = _load_artifact("categorical_encoders.pkl")
num_encoders = _load_artifact("numerical_scaler.pkl")
# Reference data; used below to populate the select-box options.
df = _load_dataset("Rabies__Weather__War_Combined_1.4.25.xlsx")
# --- Streamlit page config ---
# Browser-tab title and icon, plus the wide page layout.
_PAGE_SETTINGS = {
    "page_title": "Rabies Multi-Target Prediction",
    "page_icon": "🦠",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
# --- Custom CSS for modern look ---
# Stylesheet injected as raw HTML. It sets the app background gradient and
# restyles headings, the sidebar container, buttons and the dataframe widget.
# NOTE(review): ".css-1d391kg" looks like a generated class name and may be
# specific to one Streamlit version -- confirm it still matches.
_CUSTOM_CSS = """
<style>
/* Background gradient */
.stApp {
background: linear-gradient(to right, #3399ff, #66ccff);
color: #333;
}
/* Titles and headers */
h1, h2, h3 {
color: #004d99;
font-family: 'Arial', sans-serif;
font-weight: bold;
}
/* Sidebar style */
.stSidebar .css-1d391kg {
background-color: #ccffcc;
padding: 20px;
border-radius: 10px;
}
/* Buttons */
.stButton>button {
background-color: #007700;
color: white;
font-weight: bold;
padding: 0.6em 1.2em;
border-radius: 8px;
border: none;
transition: 0.3s;
}
.stButton>button:hover {
background-color: #004d00;
}
/* DataFrame display */
[data-testid="stDataFrame"] {
border-radius: 10px;
box-shadow: 0px 4px 6px rgba(0,0,0,0.1);
}
</style>
"""
# unsafe_allow_html is required for the <style> tag to be rendered as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# --- Input UI ---
st.title("Rabies Multi-Target Prediction Dashboard")
st.header("Predict New Record")


def _column_choices(column_name):
    """Return the sorted distinct non-null values of one column of ``df``."""
    distinct_values = df[column_name].dropna().unique()
    return sorted(distinct_values.tolist())


# --- Collect inputs ---
# Numeric fields
year = st.number_input("Year", min_value=1900, max_value=2030, value=2025)
x_coord = st.number_input("x coordinate", value=34.9896)  # east
y_coord = st.number_input("y coordinate", value=32.7940)  # north
avg_temp = st.number_input("Avg Temperature", value=25.0)
monthly_precip = st.number_input("Monthly Precipitation (mm)", value=50.0)
rainy_days = st.number_input("Rainy Days", value=4)

# Categorical fields: options come from the values present in the dataset.
animal_species = st.selectbox("Animal Species", _column_choices("Animal Species"))
rabies_species = st.selectbox("Rabies Species", _column_choices("Rabies Species"))
settlement = st.selectbox("Settlement", _column_choices("Settlement"))
region_weather = st.selectbox("Region Weather", _column_choices("Region_Weather"))
war_in_israel = st.selectbox("War in Israel", ["No", "Yes"])
# --- Create input DataFrame ---
# Single-row frame holding the values collected from the widgets above.
# NOTE(review): the column order presumably has to match the order the model
# was fitted with -- confirm against the training code.
feature_names = [
    'Year', 'Animal Species', 'Rabies Species', 'Settlement', 'x', 'y',
    'Region_Weather', 'Avg Temperature', 'Monthly Precipitation (mm)',
    'Rainy Days', 'War in Israel',
]
input_df = pd.DataFrame([[
    year, animal_species, rabies_species, settlement, x_coord, y_coord,
    region_weather, avg_temp, monthly_precip, rainy_days, war_in_israel,
]], columns=feature_names)

# --- Encode & scale ---
categorical_features = ['Year', 'Animal Species', 'Rabies Species', 'Settlement', 'Region_Weather', 'War in Israel']
numerical_features = ['x', 'y', 'Avg Temperature', 'Monthly Precipitation (mm)', 'Rainy Days']

# Track which columns were actually encoded so that they are not re-mapped
# afterwards (see the 'War in Israel' handling at the end of this section).
encoded_columns = set()
for col in categorical_features:
    if col not in cat_encoders:
        continue
    try:
        # Values are cast to str before being handed to the fitted encoder.
        input_df[col] = cat_encoders[col].transform(input_df[col].astype(str))
    except ValueError as exc:
        # Raised e.g. for a value the encoder has never seen. Report it in
        # the UI and halt this run instead of surfacing a raw traceback.
        st.error(f"Cannot encode '{col}': {exc}")
        st.stop()
    encoded_columns.add(col)

if num_encoders is not None:
    input_df[numerical_features] = num_encoders.transform(input_df[numerical_features])

st.subheader("Input Data")
st.dataframe(input_df)

# Fall back to a manual Yes/No -> 1/0 mapping only when no fitted encoder
# handled the column. Mapping an already-encoded integer through this dict
# would turn it into NaN.
if 'War in Israel' not in encoded_columns:
    input_df['War in Israel'] = input_df['War in Israel'].map({'Yes': 1, 'No': 0})
# --- Prediction ---
def _probability_table(categories, probs):
    """Return a frame of categories and their probabilities in percent, highest first."""
    return pd.DataFrame({
        "Category": categories,
        "Probability (%)": probs * 100,
    }).sort_values(by="Probability (%)", ascending=False)


def _render_probability_chart(df_probs, target):
    """Render a bar chart of ``df_probs`` for ``target`` and release the figure."""
    fig_prob, ax_prob = plt.subplots(figsize=(8, 4))
    try:
        # NOTE(review): newer seaborn releases warn when `palette` is passed
        # without `hue`; revisit once the seaborn version is known.
        sns.barplot(
            x="Category",
            y="Probability (%)",
            data=df_probs,
            palette="bright",
            ax=ax_prob,
        )
        ax_prob.set_title(f"{target} Prediction Probabilities", fontsize=14, fontweight='bold')
        ax_prob.set_xlabel("Category", fontsize=12)
        ax_prob.set_ylabel("Probability (%)", fontsize=12)
        ax_prob.set_ylim(0, 100)
        # Rotate the labels on this axes explicitly rather than relying on
        # pyplot's implicit "current axes".
        plt.setp(ax_prob.get_xticklabels(), rotation=45, ha='right')
        ax_prob.grid(alpha=0.3)
        fig_prob.tight_layout()
        st.pyplot(fig_prob)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # without this, figures pile up on every click of the button.
        plt.close(fig_prob)


if st.button("Predict"):
    try:
        st.subheader("Predicted Values with Confidence")
        if hasattr(model, "predict_proba"):
            # Indexed as one entry per target, each holding one row of class
            # probabilities for the single input record.
            proba_list = model.predict_proba(input_df)
            target_names = ['Region', 'Month']
            for i, target in enumerate(target_names):
                probs = proba_list[i][0]  # probabilities for every category
                # NOTE(review): assumes the encoder's classes_ line up with
                # the model's probability columns -- confirm.
                categories = cat_encoders[target].classes_
                df_probs = _probability_table(categories, probs)

                # The table is sorted descending, so row 0 is the top prediction.
                top_row = df_probs.iloc[0]
                top_category = top_row['Category']
                top_prob = top_row['Probability (%)']
                st.write(f"**Top Prediction for {target}: {top_category} ({top_prob:.2f}%)**")

                st.subheader(f"{target} Prediction Probabilities")
                # --- Bar chart of all probabilities ---
                _render_probability_chart(df_probs, target)
        else:
            # Only the raw prediction is available without predict_proba.
            prediction = model.predict(input_df)
            st.success(f"Prediction: {prediction}")
    except Exception as e:
        # Top-level UI boundary: report any failure in the page instead of
        # crashing the app.
        st.error(f"Error during prediction: {e}")
# --- Run Streamlit (Hugging Face Spaces) ---
if __name__ == "__main__":
    # Port taken from the PORT environment variable; defaults to 7860.
    # Nothing in this file reads the value: the launch call below is disabled.
    port = int(os.getenv("PORT", "7860"))
    # st.run(host="0.0.0.0", port=port)