| import streamlit as st |
| import pandas as pd |
| import joblib |
| import os |
|
|
| |
| |
# joblib artifacts that load_assets() reads at startup.
MODEL_FILENAME: str = "xgb_holiday_model.joblib"  # fitted model; provides predict/predict_proba
COLUMNS_FILENAME: str = "model_columns.joblib"  # column labels preprocess_input reindexes to
|
|
| |
def _resolve_asset_path(filename):
    """Return a readable location for *filename*, or ``None`` if it is absent.

    The current working directory is tried first, matching how the file was
    located before, so existing setups keep loading the same artifact. If it
    is not there, the directory containing this script is tried, so the app
    also works when launched as ``streamlit run path/to/app.py`` from another
    directory.
    """
    if os.path.exists(filename):
        return filename
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidate = os.path.join(script_dir, filename)
    if os.path.exists(candidate):
        return candidate
    return None


@st.cache_resource
def _load_joblib_assets(model_path, columns_path):
    """Deserialize the model and the column list stored at the given paths.

    Cached by Streamlit per ``(model_path, columns_path)`` pair, so the files
    are read once and the loaded objects are reused across reruns.
    """
    # joblib.load unpickles arbitrary objects: only point it at trusted files.
    return joblib.load(model_path), joblib.load(columns_path)


def load_assets():
    """Locate and load the trained model and the feature columns it expects.

    Returns:
        ``(model, model_cols)`` as deserialized by ``joblib.load``, or
        ``(None, None)`` after displaying an ``st.error`` when either file is
        missing or fails to load.

    The existence check runs outside the Streamlit cache, so a "not found"
    result is never cached. Once the files are put in place, the next rerun
    loads them without restarting the app.
    """
    model_path = _resolve_asset_path(MODEL_FILENAME)
    columns_path = _resolve_asset_path(COLUMNS_FILENAME)
    if model_path is None or columns_path is None:
        # Build the message from the constants so it cannot drift from them.
        st.error(
            f"Model or Columns file not found. Ensure '{MODEL_FILENAME}' and "
            f"'{COLUMNS_FILENAME}' are in the same folder."
        )
        return None, None
    try:
        return _load_joblib_assets(model_path, columns_path)
    except Exception as exc:  # e.g. corrupt file or library version mismatch
        # Callers already handle (None, None); report instead of crashing.
        st.error(f"Could not load model files: {exc}")
        return None, None
|
|
# Load the model and its column list. Both are None when load_assets() could
# not find the asset files; the predict handler below checks for that.
model, model_columns = load_assets()


# Page header.
st.title("Holiday Package Prediction App")
st.markdown("Enter customer details below to see the purchase probability.")


# Two-column input form. The widget values bound here (age, income, ...) are
# read at module level by the "Predict Purchase" handler further down.
col1, col2 = st.columns(2)


# Left column: numeric customer attributes.
with col1:
    age = st.number_input("Age", min_value=18, max_value=100, value=30)
    income = st.number_input("Monthly Income", min_value=1000, value=20000)
    pitch_duration = st.number_input("Duration of Pitch (min)", min_value=5, value=15)
    trips = st.number_input("Number of Trips", min_value=0, value=2)
    children = st.number_input("Children Visiting", min_value=0, max_value=5, value=1)


# Right column: categorical attributes.
with col2:
    city_tier = st.selectbox("City Tier", [1, 2, 3])
    # preprocess_input encodes "Male" as 1 and anything else as 0.
    gender = st.selectbox("Gender", ["Female", "Male"])
    # These labels must match the keys of preprocess_input's product mapping,
    # otherwise that lookup raises KeyError.
    product = st.selectbox("Product Pitched", ["Basic", "Standard", "Deluxe", "Super Deluxe", "King"])
    # NOTE(review): preprocess_input one-hot encodes this into
    # MaritalStatus_<value> and then reindexes to model_columns. A label that
    # is absent from model_columns is silently dropped, so confirm these match
    # the training categories.
    marital = st.selectbox("Marital Status", ["Married", "Unmarried", "Divorced"])
|
|
| |
# Ordinal codes for the "Product Pitched" choices offered in the form.
_PRODUCT_CODES = {
    "Basic": 0,
    "Standard": 1,
    "Deluxe": 2,
    "Super Deluxe": 3,
    "King": 4,
}

# Sentinel meaning "use the module-level model_columns". It is needed because
# None is meaningful for the columns argument: it disables reindexing.
_USE_MODEL_COLUMNS = object()


def preprocess_input(age, income, pitch_duration, trips, children, city_tier,
                     gender, product, marital, *, columns=_USE_MODEL_COLUMNS):
    """Build a one-row feature DataFrame from the form inputs.

    Args:
        age, income, pitch_duration, trips, children, city_tier: Numeric
            values copied as-is into ``Age``, ``MonthlyIncome``,
            ``DurationOfPitch``, ``NumberOfTrips``,
            ``NumberOfChildrenVisiting`` and ``CityTier``.
        gender: Encoded into ``Gender`` as 1 for ``"Male"``, 0 otherwise.
        product: A key of ``_PRODUCT_CODES``; stored as its ordinal code in
            ``ProductPitched``.
        marital: One-hot encoded by ``pd.get_dummies`` into a
            ``MaritalStatus_<value>`` column.
        columns: Column labels to align the frame to (list, pandas Index or
            array). Defaults to the module-level ``model_columns``. Pass
            ``None`` to return the un-aligned dummy-encoded frame.

    Returns:
        A single-row DataFrame. When aligned, it has exactly ``columns`` in
        that order: missing columns are filled with 0 and extra ones dropped.

    Raises:
        KeyError: If ``product`` is not a known product name.
    """
    if columns is _USE_MODEL_COLUMNS:
        columns = model_columns

    row = {
        'Age': [age],
        'DurationOfPitch': [pitch_duration],
        'NumberOfTrips': [trips],
        'MonthlyIncome': [income],
        'CityTier': [city_tier],
        'NumberOfChildrenVisiting': [children],
        'Gender': [1 if gender == 'Male' else 0],
        'ProductPitched': [_PRODUCT_CODES[product]],
        # Left as text so get_dummies expands it into MaritalStatus_* columns.
        'MaritalStatus': [marital],
    }
    df = pd.get_dummies(pd.DataFrame(row))

    # Use an explicit None check. A saved column list may be a pandas Index
    # or NumPy array, and "if columns:" raises ValueError for those.
    if columns is not None:
        df = df.reindex(columns=columns, fill_value=0)
    return df
|
|
# Run inference only on the rerun triggered by clicking the button.
if st.button("Predict Purchase"):
    if model is None or model_columns is None:
        # load_assets() could not provide both artifacts.
        st.error("Model not loaded. Please check the files.")
    else:
        try:
            # Turn the current form values into the model's feature layout.
            input_df = preprocess_input(
                age=age,
                income=income,
                pitch_duration=pitch_duration,
                trips=trips,
                children=children,
                city_tier=city_tier,
                gender=gender,
                product=product,
                marital=marital,
            )
            prediction = model.predict(input_df)
            # Probability of the positive class (column 1) for the single row.
            prob = model.predict_proba(input_df)[0][1]

            st.subheader("Results")
            if prediction[0] == 1:
                st.success(f"Prediction: **Likely to Buy** (Probability: {prob:.2%})")
            else:
                st.warning(f"Prediction: **Unlikely to Buy** (Probability: {prob:.2%})")
        except Exception as err:
            # UI boundary: surface any preprocessing/inference failure to the user.
            st.error(f"Prediction Error: {err}")