# Source: AlignAI, src/streamlit_app.py — commit 52ee0e5 ("Update src/streamlit_app.py", verified)
import streamlit as st
import joblib
import numpy as np
import plotly.graph_objects as go
# 1. Page configuration: browser-tab title/icon, wide layout, sidebar open.
# (The icon literal previously held mis-encoded bytes "πŸ›οΈ" for 🛍️.)
st.set_page_config(
    page_title="Purchase Intention AI",
    page_icon="🛍️",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS injected as raw HTML: dark app background, full-width red
# buttons, and forced-white centred text for the .main-header / .sub-text
# classes used by the page header.
_CUSTOM_CSS = """
<style>
/* Main background dark mode adjustment (optional, for contrast) */
.stApp {
background-color: #0E1117;
}
.stButton>button {
width: 100%;
background-color: #FF4B4B;
color: white;
font-weight: bold;
padding: 0.5rem;
border-radius: 10px;
}
/* TARGET 1: Main Title */
.main-header {
font-size: 2.5rem;
font-weight: 700;
color: white !important; /* Forced white color */
text-align: center;
margin-bottom: 1rem;
}
/* TARGET 2: Sub-header text */
.sub-text {
text-align: center;
color: white !important; /* Forced white color */
font-size: 1.1rem;
margin-bottom: 2rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# 2. Load model
@st.cache_resource
def load_model(model_path="src/svm_model.pkl", scaler_path="src/scaler.pkl"):
    """Load the trained model and, when one is needed, its feature scaler.

    Two artifact layouts are supported:

    * ``model_path`` holds a scikit-learn ``Pipeline`` (detected via its
      ``named_steps`` attribute), assumed to carry its own preprocessing
      -> ``(model, None)``.
    * ``model_path`` holds a bare estimator; ``scaler_path`` is loaded if it
      exists -> ``(model, scaler)``, otherwise ``(model, None)``.

    Args:
        model_path: Path to the joblib-serialised model or pipeline.
        scaler_path: Path to the joblib-serialised scaler (separate-file layout).

    Returns:
        ``(model, scaler)``; ``(None, None)`` when the model file is missing.

    Raises:
        Any non-``FileNotFoundError`` raised by ``joblib.load`` (e.g. a corrupt
        or incompatible pickle) propagates to the caller.
    """
    # NOTE: joblib.load unpickles the file; only load trusted artifacts.
    try:
        model = joblib.load(model_path)
    except FileNotFoundError:
        return None, None

    # A pipeline already scales its input; applying a separate scaler on top
    # would scale the features twice.
    if hasattr(model, "named_steps"):
        return model, None

    try:
        scaler = joblib.load(scaler_path)
    except FileNotFoundError:
        scaler = None
    return model, scaler
# Fetch the cached artifacts; without a model nothing below can run, so show
# an error and halt this script run.
model, scaler = load_model()
model_available = model is not None
if not model_available:
    st.error("🚨 Model files not found! Please run `train_model.py` first.")
    st.stop()
# 3. Header section, styled by the .main-header / .sub-text CSS classes.
# (The title emoji previously held mis-encoded bytes "πŸ›οΈ" for 🛍️.)
st.markdown('<div class="main-header">🛍️ Purchase Intention Predictor</div>', unsafe_allow_html=True)
st.markdown('<p class="sub-text">Adjust the psychometric drivers in the sidebar to predict user behavior.</p>', unsafe_allow_html=True)
st.markdown("---")
# 4. Sidebar: title and instructions for the psychometric inputs.
with st.sidebar:
    st.header("🧠 Psychometric Profiling")
    st.markdown("Adjust the behavioral scores (1-7 scale):")
def create_slider(label, key, help_text):
    """Render a 1-7 sidebar slider for one behavioural driver and return its value.

    Args:
        label: Text shown above the slider.
        key: Unique widget key (previously accepted but never passed to the
            widget); also makes the value reachable via ``st.session_state[key]``.
        help_text: Tooltip shown next to the label.

    Returns:
        The selected score as a float in [1.0, 7.0], in 0.1 steps, starting at 4.5.
    """
    return st.sidebar.slider(
        label,
        min_value=1.0,
        max_value=7.0,
        value=4.5,
        step=0.1,
        key=key,
        help=help_text,
    )
# One sidebar slider per psychometric driver, rendered top to bottom in this
# order. The unpacked names are the model features consumed further down.
_DRIVER_SLIDERS = (
    ("Attitude (ATT)", "att", "The user's positive or negative feelings toward the behavior."),
    ("Subjective Norms (SNs)", "sns", "Social pressure or influence from others to perform the behavior."),
    ("Perceived Control (PBC)", "pbc", "The user's perception of the ease or difficulty of performing the behavior."),
    ("Env. Outcome (EO)", "eo", "Expected environmental benefits resulting from the behavior."),
    ("Env. Concern (EC)", "ec", "General concern for environmental issues."),
)
att, sns, pbc, eo, ec = [
    create_slider(label, key, help_text) for label, key, help_text in _DRIVER_SLIDERS
]
# 5. Main content area: a narrower result column beside a wider profile column.
column_widths = [1, 1.5]
col1, col2 = st.columns(column_widths)
# Prepare a single-row feature matrix from the slider scores.
# NOTE(review): column order assumed to match training — confirm against train_model.py.
input_values = np.array([[att, sns, pbc, eo, ec]], dtype=float)

# Apply the separately loaded scaler only when one exists; compare with None
# explicitly instead of relying on the estimator object's truthiness.
final_input = scaler.transform(input_values) if scaler is not None else input_values

# Real-time prediction as a plain float, clamped to the 1-7 survey scale.
prediction = float(model.predict(final_input)[0])
prediction = max(1.0, min(7.0, prediction))
with col1:
    # (Subheader emoji previously held mis-encoded bytes "πŸ“Š" for 📊.)
    st.subheader("📊 Prediction Result")

    # Band edges on the 1-7 scale, shared by the gauge shading and the text
    # interpretation so the two cannot drift apart.
    moderate_cutoff, high_cutoff = 3.5, 5.5

    # Gauge chart
    fig_gauge = go.Figure(go.Indicator(
        mode="gauge+number",
        value=prediction,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Purchase Intention (PI)"},
        gauge={
            'axis': {'range': [1, 7]},
            'bar': {'color': "#FF4B4B"},
            'steps': [
                {'range': [1, moderate_cutoff], 'color': "#f8f9fa"},
                {'range': [moderate_cutoff, high_cutoff], 'color': "#e9ecef"},
                {'range': [high_cutoff, 7], 'color': "#dee2e6"},
            ],
            # Red marker line drawn at the predicted value itself.
            'threshold': {
                'line': {'color': "red", 'width': 4},
                'thickness': 0.75,
                'value': prediction,
            },
        },
    ))
    fig_gauge.update_layout(height=350, margin=dict(l=20, r=20, t=50, b=20))
    st.plotly_chart(fig_gauge, use_container_width=True)

    # Text interpretation using the same bands as the gauge
    if prediction >= high_cutoff:
        st.success("High Probability: User is likely to purchase.")
    elif prediction >= moderate_cutoff:
        st.info("Moderate Probability: User is undecided.")
    else:
        st.warning("Low Probability: User is unlikely to purchase.")
with col2:
    # (Subheader emoji previously held mis-encoded bytes "πŸ•ΈοΈ" for 🕸️.)
    st.subheader("🕸️ User Profile Analysis")

    # Radar chart of the five raw (unscaled) slider scores, in slider order.
    categories = ['Attitude', 'Social Norms', 'Control', 'Outcome', 'Concern']
    r_values = [att, sns, pbc, eo, ec]

    fig_radar = go.Figure()
    fig_radar.add_trace(go.Scatterpolar(
        r=r_values,
        theta=categories,
        fill='toself',
        name='User Profile',
        line_color='#00CC96',
    ))
    fig_radar.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 7],
            ),
        ),
        showlegend=False,
        height=350,
        margin=dict(l=40, r=40, t=20, b=20),
    )
    st.plotly_chart(fig_radar, use_container_width=True)
# Footer: a divider followed by a small italic caption describing the model
# and the response scale.
footer_caption = (
    "###### *Model: Support Vector Machine (RBF Kernel) | "
    "Data Scale: 1 (Strongly Disagree) - 7 (Strongly Agree)*"
)
st.markdown("---")
st.markdown(footer_caption)