AlignAI commited on
Commit
8d3494c
Β·
verified Β·
1 Parent(s): b8b9ebc

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +154 -34
src/streamlit_app.py CHANGED
@@ -1,49 +1,169 @@
1
  import streamlit as st
2
  import joblib
3
  import numpy as np
 
4
 
5
- # Load the trained model and scaler
6
- # Use @st.cache_resource to load them only once for performance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @st.cache_resource
8
  def load_model():
9
- model = joblib.load('src/svm_model.pkl')
10
- scaler = joblib.load('src/scaler.pkl')
11
- return model, scaler
12
-
13
- try:
14
- model, scaler = load_model()
15
- except FileNotFoundError:
16
- st.error("Model files not found. Please run train_model.py first.")
 
 
 
 
 
 
 
 
 
 
17
  st.stop()
18
 
19
- st.title("Purchase Intention Predictor")
20
- st.write("Adjust the sliders below to predict the user's Purchase Intention (PI).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Sidebar for inputs
23
- st.sidebar.header("User Inputs")
24
 
25
- # Create sliders for each feature based on the data's 1-7 scale
26
- att = st.sidebar.slider("Attitude (ATT)", min_value=1.0, max_value=7.0, value=4.0, step=0.1)
27
- sns = st.sidebar.slider("Subjective Norms (SNs)", min_value=1.0, max_value=7.0, value=4.0, step=0.1)
28
- pbc = st.sidebar.slider("Perceived Behavioral Control (PBC)", min_value=1.0, max_value=7.0, value=4.0, step=0.1)
29
- eo = st.sidebar.slider("Environmental Outcome (EO)", min_value=1.0, max_value=7.0, value=4.0, step=0.1)
30
- ec = st.sidebar.slider("Environmental Concern (EC)", min_value=1.0, max_value=7.0, value=4.0, step=0.1)
31
 
32
- # Prepare input data
33
- input_data = np.array([[att, sns, pbc, eo, ec]])
 
 
 
34
 
35
- # Scale the input
36
- input_scaled = scaler.transform(input_data)
 
 
37
 
38
- # Predict
39
- if st.button("Predict Purchase Intention"):
40
- prediction = model.predict(input_scaled)
41
- st.subheader(f"Predicted Purchase Intention Score: {prediction[0]:.2f}")
42
 
43
- # Optional: Interpretation
44
- if prediction[0] > 5.5:
45
- st.success("High Purchase Intention")
46
- elif prediction[0] < 3.5:
47
- st.warning("Low Purchase Intention")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  else:
49
- st.info("Moderate Purchase Intention")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import joblib
3
  import numpy as np
4
+ import plotly.graph_objects as go
5
 
6
+ # 1. Page Configuration (Must be the first command)
7
+ st.set_page_config(
8
+ page_title="Purchase Intention AI",
9
+ page_icon="πŸ›οΈ",
10
+ layout="wide",
11
+ initial_sidebar_state="expanded"
12
+ )
13
+
14
+ # Custom CSS for styling
15
+ st.markdown("""
16
+ <style>
17
+ .stButton>button {
18
+ width: 100%;
19
+ background-color: #FF4B4B;
20
+ color: white;
21
+ font-weight: bold;
22
+ padding: 0.5rem;
23
+ border-radius: 10px;
24
+ }
25
+ .main-header {
26
+ font-size: 2.5rem;
27
+ font-weight: 700;
28
+ color: #333;
29
+ text-align: center;
30
+ margin-bottom: 2rem;
31
+ }
32
+ .sub-text {
33
+ text-align: center;
34
+ color: #666;
35
+ font-size: 1.1rem;
36
+ }
37
+ </style>
38
+ """, unsafe_allow_html=True)
39
+
40
+ # 2. Load Model
41
  @st.cache_resource
42
  def load_model():
43
+ # Handles both the pipeline (smart) version and separate files version
44
+ try:
45
+ # Try loading the smart pipeline first
46
+ model = joblib.load('src/svm_model.pkl')
47
+ return model, None # No separate scaler needed
48
+ except:
49
+ try:
50
+ # Fallback to separate files
51
+ model = joblib.load('src/svm_model.pkl')
52
+ scaler = joblib.load('src/scaler.pkl')
53
+ return model, scaler
54
+ except FileNotFoundError:
55
+ return None, None
56
+
57
+ model, scaler = load_model()
58
+
59
+ if model is None:
60
+ st.error("🚨 Model files not found! Please run `train_model.py` first.")
61
  st.stop()
62
 
63
+ # 3. Header Section
64
+ st.markdown('<div class="main-header">πŸ›οΈ Purchase Intention Predictor</div>', unsafe_allow_html=True)
65
+ st.markdown('<p class="sub-text">Adjust the psychometric drivers in the sidebar to predict user behavior.</p>', unsafe_allow_html=True)
66
+ st.markdown("---")
67
+
68
+ # 4. Sidebar - User Inputs
69
+ st.sidebar.header("🧠 Psychometric Profiling")
70
+ st.sidebar.markdown("Adjust the behavioral scores (1-7 scale):")
71
+
72
+ def create_slider(label, key, help_text):
73
+ return st.sidebar.slider(
74
+ label,
75
+ min_value=1.0,
76
+ max_value=7.0,
77
+ value=4.5,
78
+ step=0.1,
79
+ help=help_text
80
+ )
81
+
82
+ att = create_slider("Attitude (ATT)", "att", "The user's positive or negative feelings toward the behavior.")
83
+ sns = create_slider("Subjective Norms (SNs)", "sns", "Social pressure or influence from others to perform the behavior.")
84
+ pbc = create_slider("Perceived Control (PBC)", "pbc", "The user's perception of the ease or difficulty of performing the behavior.")
85
+ eo = create_slider("Env. Outcome (EO)", "eo", "Expected environmental benefits resulting from the behavior.")
86
+ ec = create_slider("Env. Concern (EC)", "ec", "General concern for environmental issues.")
87
 
88
+ # 5. Main Content Area
89
+ col1, col2 = st.columns([1, 1.5])
90
 
91
+ # Prepare Input
92
+ input_values = np.array([[att, sns, pbc, eo, ec]])
 
 
 
 
93
 
94
+ # Handle Scaling
95
+ if scaler:
96
+ final_input = scaler.transform(input_values)
97
+ else:
98
+ final_input = input_values # Pipeline handles scaling internally
99
 
100
+ # Real-time Prediction (or on button click)
101
+ prediction = model.predict(final_input)[0]
102
+ # Clip prediction to 1-7 range for visuals
103
+ prediction = max(1.0, min(7.0, prediction))
104
 
105
+ with col1:
106
+ st.subheader("πŸ“Š Prediction Result")
 
 
107
 
108
+ # Gauge Chart for PI
109
+ fig_gauge = go.Figure(go.Indicator(
110
+ mode = "gauge+number",
111
+ value = prediction,
112
+ domain = {'x': [0, 1], 'y': [0, 1]},
113
+ title = {'text': "Purchase Intention (PI)"},
114
+ gauge = {
115
+ 'axis': {'range': [1, 7]},
116
+ 'bar': {'color': "#FF4B4B"},
117
+ 'steps': [
118
+ {'range': [1, 3.5], 'color': "#f8f9fa"},
119
+ {'range': [3.5, 5.5], 'color': "#e9ecef"},
120
+ {'range': [5.5, 7], 'color': "#dee2e6"}
121
+ ],
122
+ 'threshold': {
123
+ 'line': {'color': "red", 'width': 4},
124
+ 'thickness': 0.75,
125
+ 'value': prediction
126
+ }
127
+ }
128
+ ))
129
+ fig_gauge.update_layout(height=350, margin=dict(l=20,r=20,t=50,b=20))
130
+ st.plotly_chart(fig_gauge, use_container_width=True)
131
+
132
+ # Text Interpretation
133
+ if prediction >= 5.5:
134
+ st.success("**High Probability**: User is very likely to purchase.")
135
+ elif prediction >= 3.5:
136
+ st.info("**Moderate Probability**: User is undecided.")
137
  else:
138
+ st.warning("**Low Probability**: User is unlikely to purchase.")
139
+
140
+ with col2:
141
+ st.subheader("πŸ•ΈοΈ User Profile Analysis")
142
+
143
+ # Radar Chart for Inputs
144
+ categories = ['Attitude', 'Social Norms', 'Control', 'Outcome', 'Concern']
145
+ r_values = [att, sns, pbc, eo, ec]
146
+
147
+ fig_radar = go.Figure()
148
+ fig_radar.add_trace(go.Scatterpolar(
149
+ r=r_values,
150
+ theta=categories,
151
+ fill='toself',
152
+ name='User Profile',
153
+ line_color='#00CC96'
154
+ ))
155
+
156
+ fig_radar.update_layout(
157
+ polar=dict(
158
+ radialaxis=dict(
159
+ visible=True,
160
+ range=[0, 7]
161
+ )),
162
+ showlegend=False,
163
+ height=350,
164
+ margin=dict(l=40,r=40,t=20,b=20)
165
+ )
166
+ st.plotly_chart(fig_radar, use_container_width=True)
167
+
168
+ st.markdown("---")
169
+ st.markdown("###### *Model: Support Vector Machine (RBF Kernel) | Data Scale: 1 (Strongly Disagree) - 7 (Strongly Agree)*")