Anjini-Katari commited on
Commit
df26fa6
·
verified ·
1 Parent(s): 487933f

Upload 4 files

Browse files
src/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🎈 Blank app template
2
+
3
+ A simple Streamlit app template for you to modify!
4
+
5
+ [![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://blank-app-template.streamlit.app/)
6
+
7
+ ### How to run it on your own machine
8
+
9
+ 1. Install the requirements
10
+
11
+ ```
12
+ $ pip install -r requirements.txt
13
+ ```
14
+
15
+ 2. Run the app
16
+
17
+ ```
18
+ $ streamlit run streamlit_app.py
19
+ ```
src/ocd_patient_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/requirements.txt ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ alembic==1.16.2
2
+ altair==5.5.0
3
+ annotated-types==0.7.0
4
+ anyio==4.9.0
5
+ asttokens==3.0.0
6
+ attrs==25.3.0
7
+ blinker==1.9.0
8
+ cachetools==5.5.2
9
+ category_encoders==2.7.0
10
+ certifi==2025.6.15
11
+ charset-normalizer==3.4.2
12
+ choreographer==1.0.9
13
+ click==8.2.1
14
+ cloudpickle==3.1.1
15
+ comm==0.2.2
16
+ contourpy==1.3.2
17
+ cycler==0.12.1
18
+ Cython==3.1.2
19
+ dash==3.1.0
20
+ databricks-sdk==0.57.0
21
+ decorator==5.2.1
22
+ deprecation==2.1.0
23
+ docker==7.1.0
24
+ executing==2.2.0
25
+ fastapi==0.115.14
26
+ fastjsonschema==2.21.1
27
+ filelock==3.18.0
28
+ Flask==3.1.1
29
+ fonttools==4.58.4
30
+ fsspec==2025.5.1
31
+ gitdb==4.0.12
32
+ GitPython==3.1.41
33
+ google-auth==2.40.3
34
+ graphene==3.4.3
35
+ graphql-core==3.2.6
36
+ graphql-relay==3.2.0
37
+ greenlet==3.2.3
38
+ gunicorn==23.0.0
39
+ h11==0.16.0
40
+ idna==3.10
41
+ imbalanced-learn==0.13.0
42
+ importlib_metadata==8.7.0
43
+ ipython==9.3.0
44
+ ipython_pygments_lexers==1.1.1
45
+ ipywidgets==8.1.7
46
+ itsdangerous==2.2.0
47
+ jedi==0.19.2
48
+ Jinja2==3.1.6
49
+ joblib==1.3.2
50
+ jsonschema==4.24.0
51
+ jsonschema-specifications==2025.4.1
52
+ jupyter_core==5.8.1
53
+ jupyterlab_widgets==3.0.15
54
+ kaleido==1.0.0
55
+ kiwisolver==1.4.8
56
+ lightgbm==4.6.0
57
+ llvmlite==0.44.0
58
+ logistro==1.1.0
59
+ Mako==1.3.10
60
+ MarkupSafe==3.0.2
61
+ matplotlib==3.7.5
62
+ matplotlib-inline==0.1.7
63
+ mlflow==3.1.1
64
+ mlflow-skinny==3.1.1
65
+ mpmath==1.3.0
66
+ narwhals==1.44.0
67
+ nbformat==5.10.4
68
+ nest-asyncio==1.6.0
69
+ networkx==3.5
70
+ numba==0.61.2
71
+ numpy==1.26.4
72
+ nvidia-cublas-cu12==12.6.4.1
73
+ nvidia-cuda-cupti-cu12==12.6.80
74
+ nvidia-cuda-nvrtc-cu12==12.6.77
75
+ nvidia-cuda-runtime-cu12==12.6.77
76
+ nvidia-cudnn-cu12==9.5.1.17
77
+ nvidia-cufft-cu12==11.3.0.4
78
+ nvidia-cufile-cu12==1.11.1.6
79
+ nvidia-curand-cu12==10.3.7.77
80
+ nvidia-cusolver-cu12==11.7.1.2
81
+ nvidia-cusparse-cu12==12.5.4.2
82
+ nvidia-cusparselt-cu12==0.6.3
83
+ nvidia-nccl-cu12==2.26.2
84
+ nvidia-nvjitlink-cu12==12.6.85
85
+ nvidia-nvtx-cu12==12.6.77
86
+ opentelemetry-api==1.34.1
87
+ opentelemetry-sdk==1.34.1
88
+ opentelemetry-semantic-conventions==0.55b1
89
+ orjson==3.10.18
90
+ packaging==25.0
91
+ pandas==2.1.4
92
+ parso==0.8.4
93
+ patsy==1.0.1
94
+ pexpect==4.9.0
95
+ pillow==11.2.1
96
+ platformdirs==4.3.8
97
+ plotly==5.24.1
98
+ plotly-express==0.4.1
99
+ plotly-resampler==0.10.0
100
+ pmdarima==2.0.4
101
+ prompt_toolkit==3.0.51
102
+ protobuf==6.31.1
103
+ psutil==7.0.0
104
+ ptyprocess==0.7.0
105
+ pure_eval==0.2.3
106
+ pyarrow==20.0.0
107
+ pyasn1==0.6.1
108
+ pyasn1_modules==0.4.2
109
+ pycaret==3.3.2
110
+ pydantic==2.11.7
111
+ pydantic_core==2.33.2
112
+ pydeck==0.9.1
113
+ Pygments==2.19.2
114
+ pyod==2.0.5
115
+ pyparsing==3.2.3
116
+ python-dateutil==2.9.0.post0
117
+ pytz==2025.2
118
+ PyYAML==6.0.2
119
+ referencing==0.36.2
120
+ requests==2.32.4
121
+ retrying==1.4.0
122
+ rpds-py==0.25.1
123
+ rsa==4.9.1
124
+ schemdraw==0.15
125
+ scikit-base==0.7.8
126
+ scikit-learn==1.4.2
127
+ scikit-plot==0.3.7
128
+ scipy==1.11.4
129
+ seaborn==0.13.2
130
+ shap==0.48.0
131
+ simplejson==3.20.1
132
+ six==1.17.0
133
+ sklearn-compat==0.1.3
134
+ sktime==0.26.0
135
+ slicer==0.0.8
136
+ smmap==5.0.2
137
+ sniffio==1.3.1
138
+ SQLAlchemy==2.0.41
139
+ sqlparse==0.5.3
140
+ stack-data==0.6.3
141
+ starlette==0.46.2
142
+ statsmodels==0.14.4
143
+ streamlit==1.46.1
144
+ sympy==1.14.0
145
+ tbats==1.1.3
146
+ tenacity==9.1.2
147
+ threadpoolctl==3.6.0
148
+ toml==0.10.2
149
+ torch==2.7.1
150
+ tornado==6.5.1
151
+ tqdm==4.67.1
152
+ traitlets==5.14.3
153
+ triton==3.3.1
154
+ tsdownsample==0.1.4.1
155
+ typing-inspection==0.4.1
156
+ typing_extensions==4.14.0
157
+ tzdata==2025.2
158
+ urllib3==2.5.0
159
+ uvicorn==0.34.3
160
+ watchdog==6.0.0
161
+ wcwidth==0.2.13
162
+ Werkzeug==3.1.3
163
+ widgetsnbextension==4.0.14
164
+ wurlitzer==3.1.1
165
+ xxhash==3.5.0
166
+ yellowbrick==1.5
167
+ zipp==3.23.0
src/streamlit_app.py CHANGED
@@ -1,40 +1,525 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+
 
 
2
  import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+
8
+ import plotly.express as px
9
+ import plotly.graph_objects as go
10
+ from datetime import datetime
11
+ import warnings
12
+ from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
13
+ from sklearn.preprocessing import LabelEncoder, StandardScaler, MinMaxScaler
14
+ from sklearn.linear_model import LinearRegression, LogisticRegression
15
+ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
16
+ from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
17
+ from sklearn.svm import SVC, SVR
18
+ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
19
+ from sklearn.naive_bayes import GaussianNB
20
+ from sklearn.metrics import (
21
+ mean_squared_error, mean_absolute_error, r2_score,
22
+ accuracy_score, precision_score, recall_score, f1_score,
23
+ confusion_matrix, classification_report, roc_auc_score
24
+ )
25
+ warnings.filterwarnings('ignore')
26
+
27
+ # MLflow and experiment tracking
28
+ try:
29
+ import mlflow
30
+ import mlflow.sklearn
31
+ MLFLOW_AVAILABLE = True
32
+ except ImportError:
33
+ MLFLOW_AVAILABLE = False
34
+ st.warning("MLflow not installed. Some features may be limited.")
35
+
36
+ # PyCaret imports
37
+ try:
38
+ from pycaret.classification import setup as cls_setup, compare_models as cls_compare, create_model as cls_create
39
+ from pycaret.classification import tune_model as cls_tune, finalize_model as cls_finalize, predict_model as cls_predict
40
+ from pycaret.classification import pull as cls_pull, plot_model as cls_plot, evaluate_model as cls_evaluate
41
+ from pycaret.regression import setup as reg_setup, compare_models as reg_compare, create_model as reg_create
42
+ from pycaret.regression import tune_model as reg_tune, finalize_model as reg_finalize, predict_model as reg_predict
43
+ from pycaret.regression import pull as reg_pull, plot_model as reg_plot, evaluate_model as reg_evaluate
44
+ PYCARET_AVAILABLE = True
45
+ except ImportError:
46
+ PYCARET_AVAILABLE = False
47
+ st.warning("PyCaret not installed. AutoML features will be limited.")
48
+
49
+ # Data profiling
50
+ #try:
51
+ # from ydata_profiling import ProfileReport
52
+ # from streamlit_pandas_profiling import st_profile_report
53
+ # PROFILING_AVAILABLE = True
54
+ #except ImportError:
55
+ # PROFILING_AVAILABLE = False
56
+
57
+ # PyTorch for deep learning
58
+ try:
59
+ import torch
60
+ import torch.nn as nn
61
+ import torch.optim as optim
62
+ from torch.utils.data import TensorDataset, DataLoader
63
+ TORCH_AVAILABLE = True
64
+ except ImportError:
65
+ TORCH_AVAILABLE = False
66
+
67
+ # SHAP for explainability
68
+ try:
69
+ import shap
70
+ SHAP_AVAILABLE = True
71
+ except ImportError:
72
+ SHAP_AVAILABLE = False
73
+ # Scikit-learn imports
74
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
75
+ from sklearn.linear_model import LinearRegression, LogisticRegression
76
+ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
77
+ from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
78
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
79
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
80
+
81
+ # ================== UPLOADING THE DATA ==================
82
+
83
# Load the dataset from the directory that contains this script instead of the
# current working directory, so the app also starts when Streamlit is launched
# from somewhere else (e.g. `streamlit run src/streamlit_app.py`).
from pathlib import Path

df = pd.read_csv(Path(__file__).resolve().parent / "ocd_patient_dataset.csv")
84
+
85
+ # ================== CUSTOM CSS & STYLING ==================
86
# Browser-tab title/icon and the overall page layout.
st.set_page_config(
    page_title="OCD Diagnosing",
    # page_icon expects a single emoji (or an image / shortcode); the previous
    # value was a mis-encoded multi-character string.
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded",
)
92
+
93
# Global look-and-feel overrides for the app, kept in one named constant so the
# stylesheet can be read (and edited) separately from the call that injects it.
_CUSTOM_CSS = """
<style>
/* Main styling */
.main {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
font-family: 'Arial', sans-serif;
}

/* Sidebar styling */
.sidebar .sidebar-content {
background: linear-gradient(180deg, #2C3E50, #3498DB);
color: white;
}

/* Button styling */
.stButton > button {
background: linear-gradient(45deg, #FF6B6B, #4ECDC4);
color: white;
border: none;
border-radius: 25px;
padding: 0.6rem 1.5rem;
font-weight: bold;
transition: all 0.3s ease;
box-shadow: 0 4px 15px 0 rgba(31, 38, 135, 0.37);
}

.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 8px 25px 0 rgba(31, 38, 135, 0.37);
}

/* Metric styling */
.metric-container {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem 0;
border: 1px solid rgba(255, 255, 255, 0.2);
}

/* Header styling */
.main-header {
text-align: center;
padding: 2rem 0;
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 20px;
margin-bottom: 2rem;
border: 1px solid rgba(255, 255, 255, 0.2);
}

/* Success/Error messages */
.stSuccess, .stError, .stWarning {
border-radius: 10px;
border: none;
}
</style>
"""

# The stylesheet is raw HTML, so it has to be passed with unsafe_allow_html.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
152
+
153
+ # ================== HEADER ==================
154
# Title banner shown at the top of every page; styled by the .main-header CSS
# rule injected above.
st.markdown("""
<div class="main-header">
<h1 style="color: white; font-size: 3rem; margin-bottom: 0;">OCD Diagnosing</h1>
<p style="color: rgba(255,255,255,0.8); font-size: 1.2rem;">
Test different factors on their predictability of OCD using ML Models
</p>
</div>
""", unsafe_allow_html=True)
162
+
163
+ # ================== AUTHENTICATION ==================
164
def check_authentication():
    """Require a sidebar login (or demo mode) before the rest of the script runs.

    ``st.session_state.authenticated`` is initialised to ``False`` the first
    time this runs in a session. While it is ``False``, a password form and a
    "Demo Mode" button are rendered in the sidebar and the script is halted
    with ``st.stop()``. A successful login or a click on "Demo Mode" sets the
    flag and reruns the script, after which this function returns immediately.

    The expected password is read from the ``OCD_APP_PASSWORD`` environment
    variable and falls back to the original built-in value, so existing
    deployments keep working unchanged.
    """
    import hmac
    import os

    if 'authenticated' not in st.session_state:
        st.session_state.authenticated = False

    # Guard clause: nothing to render once the session is authenticated.
    if st.session_state.authenticated:
        return

    # NOTE(review): the fallback is a credential committed to source control.
    # Set OCD_APP_PASSWORD in the deployment and remove the default when possible.
    expected_password = os.environ.get("OCD_APP_PASSWORD", "diagnosis testing")

    with st.sidebar:
        st.header("🔒 Authentication")
        password = st.text_input("Enter Password", type="password", key="auth_password")
        col1, col2 = st.columns(2)
        with col1:
            if st.button("🔑 Login", key="login_btn"):
                # Constant-time comparison, so the check does not reveal how
                # much of the submitted password matched.
                if hmac.compare_digest(password.encode("utf-8"), expected_password.encode("utf-8")):
                    st.session_state.authenticated = True
                    st.success("✅ Access Granted!")
                    st.rerun()
                else:
                    st.error("❌ Incorrect Password")
        with col2:
            # NOTE(review): demo mode sets the same `authenticated` flag as a
            # successful login, so the password does not restrict access.
            if st.button("👀 Demo Mode", key="demo_btn"):
                st.session_state.authenticated = True
                st.session_state.demo_mode = True
                st.info("📊 Demo Mode Activated")
                st.rerun()

    st.info("🔐 Please authenticate to access the application")
    # Halt here so none of the page code below runs for unauthenticated sessions.
    st.stop()
190
+
191
check_authentication()

# ================== SESSION STATE INITIALIZATION ==================
# Each entry is (key, factory). Factories are used instead of literal values so
# that every session gets its own fresh dict rather than a shared one.
_SESSION_DEFAULTS = (
    ('df', lambda: None),
    ('trained_models', dict),
    ('pycaret_setup_done', lambda: False),
    ('best_model', lambda: None),
    ('dl_models', dict),
    ('training_history', dict),
)

for _state_key, _make_default in _SESSION_DEFAULTS:
    # Only seed keys that are missing; values from earlier reruns are kept.
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
206
+
207
+ # ================== SIDEBAR NAVIGATION ==================
208
+ #PAGES
209
# Page labels double as routing keys: the page branches further down compare
# `selected_page` against these exact strings, so they must stay in sync.
pages = [
    "🏠 Home",
    "πŸ“Š Data Viz",
    "πŸ€– Logistical Regression",
    "🌳 Decision Tree",
    "Model Comparison",
]
# "πŸ“‹ MLflow Tracking" is intentionally not offered (its page is disabled).

with st.sidebar:
    st.title("🧭 Navigation")
    selected_page = st.selectbox("Select Page", pages, key="page_selector")
220
+
221
+
222
+ # ================== PAGE CONTENT ==================
223
+
224
if selected_page == "🏠 Home":
    # Only the wide middle column is used; the outer two act as margins.
    _, col2, _ = st.columns([1, 2, 1])

    with col2:
        st.markdown("""
## OCD Diagnosis Deep Dive

About the data
There is an ongoing issue of misdiagnosis among mental illnesses, like OCD. Machine Learning has the ability to make diagnosing easier.
This app aims to use factors such as OCD Diagnosis Date, Duration of Symptoms in months, Previous Diagnoses, Family History of OCD,
Obsession Type, and Compulsion Type, to see if we can accurately predict the obsession and/or compulsion type.

""")

    # Preview of the first rows of the dataset.
    st.table(df.head())
239
+
240
+
241
+ #DATA VIZ
242
+ elif selected_page == "πŸ“Š Data Viz":
243
+ filtds = df.drop(columns=["Patient ID"])
244
+
245
+ col_x = st.selectbox("Select X-axis variable (group by)", filtds.columns)
246
+ col_y = st.selectbox("Select Y-axis variable (numeric)", filtds.columns)
247
+
248
+ tab1, tab2, tab3, tab4 = st.tabs(["Box plot", "Bar Chart πŸ“Š","Line Chart πŸ“ˆ","Correlation Heatmap πŸ”₯",])
249
+
250
+ with tab1:
251
+ st.subheader("Box plot")
252
+ fig, ax = plt.subplots()
253
+ sns.boxplot(data=df, x=col_x, y=col_y, ax=ax)
254
+ ax.set_title(f"{col_y} by {col_x}")
255
+ st.pyplot(fig)
256
+
257
+ with tab2:
258
+ st.subheader("Bar Chart")
259
+ st.bar_chart(df[[col_x,col_y]].sort_values(by=col_x),use_container_width=True)
260
+
261
+ with tab3:
262
+ st.subheader("Line Chart")
263
+ st.line_chart(df[[col_x,col_y]].sort_values(by=col_x),use_container_width=True)
264
+
265
+ with tab4:
266
+ st.subheader("Correlation Matrix")
267
+ df_numeric = df.select_dtypes(include=np.number)
268
+
269
+ ct = pd.crosstab(df[col_x], df[col_y])
270
+ sns.heatmap(ct, annot=True, fmt='d', cmap='Blues')
271
+ plt.xlabel(col_y)
272
+ plt.ylabel(col_x)
273
+ plt.title(f"Heatmap of {col_x} vs {col_y}")
274
+
275
+ #LOG REG
276
+ elif selected_page == "πŸ€– Logistical Regression":
277
+ st.header("Running a Logistical Regression on our data...")
278
+
279
+ target_variable = st.selectbox(
280
+ "Select which variable you would like to predict:",
281
+ ["Y-BOCS Score (Obsessions)", "Y-BOCS Score (Compulsions)", "Depression Diagnosis", "Anxiety Diagnosis"]
282
+ )
283
+ if st.button("Train Model"):
284
+ with st.spinner("Training model..."):
285
+ try:
286
+ df_sampled = df.sample(n=500, random_state=42)
287
+ X = df_sampled.drop(columns=[target_variable])
288
+ X = X.select_dtypes(include=["number"])
289
+ y = df_sampled[target_variable]
290
+
291
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
292
+
293
+ scaler = StandardScaler()
294
+ X_train_scaled = scaler.fit_transform(X_train)
295
+ X_test_scaled = scaler.transform(X_test)
296
+
297
+ model = LogisticRegression()
298
+ model.fit(X_train_scaled, y_train)
299
+
300
+ y_pred = model.predict(X_test_scaled)
301
+
302
+ st.write("### Accuracy:", accuracy_score(y_test, y_pred))
303
+ st.write("### Classification Report:")
304
+ st.text(classification_report(y_test, y_pred))
305
+
306
+ st.subheader("πŸ“Š SHAP Summary Plot for Logistic Regression")
307
+ fig2 = shap.plots.beeswarm(shap_values, show=False)
308
+ st.pyplot(bbox_inches='tight')
309
+ plt.clf()
310
+
311
+ except Exception as e:
312
+ st.error(f"❌ Error training model: {str(e)}")
313
+
314
+ elif selected_page == "🌳 Decision Tree":
315
+ st.header("Predictions via decision tree...")
316
+
317
+ target_variable = st.selectbox(
318
+ "Select which variable you would like to predict:",
319
+ ["Y-BOCS Score (Obsessions)", "Y-BOCS Score (Compulsions)", "Depression Diagnosis", "Anxiety Diagnosis"]
320
+ )
321
+
322
+ X = df.drop(columns=[target_variable])
323
+ X = X.select_dtypes(include=["number"])
324
+ X = X.fillna(X.mean())
325
+ y = df[target_variable]
326
+
327
+ # split
328
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
329
+
330
+ # train tree
331
+ dt_model = DecisionTreeClassifier(max_depth=5, random_state=42) # You can adjust depth
332
+ dt_model.fit(X_train, y_train)
333
+
334
+ y_pred = dt_model.predict(X_test)
335
+
336
+ st.write("### 🌳 Decision Tree Performance")
337
+ st.write("**Accuracy:**", accuracy_score(y_test, y_pred))
338
+ st.write("**Classification Report:**")
339
+ st.text(classification_report(y_test, y_pred))
340
+
341
+ explainer = shap.Explainer(dt_model, X_test)
342
+ shap_values = explainer(X_test)
343
+
344
+ # Summary plot (global feature importance)
345
+ st.subheader("πŸ“Š SHAP Summary Plot")
346
+ fig1 = shap.plots.beeswarm(shap_values, show=False)
347
+ st.pyplot(bbox_inches='tight')
348
+ plt.clf()
349
+
350
+
351
+ elif selected_page == "Model Comparison":
352
+ st.header("Decision Tree vs Logistic Regression")
353
+
354
+ target_variable = st.selectbox(
355
+ "🎯 Select the target variable to predict:",
356
+ ["Y-BOCS Score (Obsessions)", "Y-BOCS Score (Compulsions)", "Depression Diagnosis", "Anxiety Diagnosis"])
357
+
358
+ df_sampled = df.sample(n=500, random_state=42)
359
+ X = df_sampled.drop(columns=[target_variable])
360
+ X = X.select_dtypes(include=["number"])
361
+ X = X.fillna(X.mean())
362
+ y = df_sampled[target_variable]
363
+
364
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
365
+
366
+ scaler = StandardScaler()
367
+ X_train_scaled = scaler.fit_transform(X_train)
368
+ X_test_scaled = scaler.transform(X_test)
369
+
370
+ logreg = LogisticRegression(max_iter=1000)
371
+ dtree = DecisionTreeClassifier(max_depth=5, random_state=42)
372
+
373
+ mlflow.set_tracking_uri("http://127.0.0.1:5000")
374
+ mlflow.set_experiment("OCD")
375
+
376
+ col1, col2 = st.columns(2)
377
+ with col1:
378
+ with mlflow.start_run(run_name="Decision Tree"):
379
+ dtree.fit(X_train, y_train)
380
+ y_pred_tree = dtree.predict(X_test)
381
+ y_proba_tree = dtree.predict_proba(X_test)[:, 1]
382
+
383
+ st.markdown("### 🌿 Decision Tree")
384
+ st.write("**Accuracy:**", accuracy_score(y_test, y_pred_tree))
385
+ st.text(classification_report(y_test, y_pred_tree))
386
+
387
+ cm_tree = confusion_matrix(y_test, y_pred_tree)
388
+ fig1, ax1 = plt.subplots()
389
+ sns.heatmap(cm_tree, annot=True, fmt='d', cmap='Greens', ax=ax1)
390
+ ax1.set_title("Decision Tree Confusion Matrix")
391
+ st.pyplot(fig1)
392
+ plt.close(fig1)
393
+
394
+ st.session_state.trained_models = st.session_state.get("trained_models", {})
395
+ st.session_state.trained_models["Decision Tree"] = {
396
+ "model": dtree,
397
+ "features": X.columns.tolist(),
398
+ "target": target_variable,
399
+ "predictions": y_pred_tree,
400
+ "y_test": y_test,
401
+ "problem_type": "Classification"
402
+ }
403
+
404
+
405
+ with col2:
406
+ with mlflow.start_run(run_name="Logistic Regression"):
407
+ logreg.fit(X_train_scaled, y_train)
408
+ y_pred_log = logreg.predict(X_test_scaled)
409
+ y_proba_log = logreg.predict_proba(X_test_scaled)[:, 1]
410
+
411
+ st.markdown("### πŸ“ˆ Logistic Regression")
412
+ st.write("**Accuracy:**", accuracy_score(y_test, y_pred_log))
413
+ st.text(classification_report(y_test, y_pred_log))
414
+
415
+ cm_log = confusion_matrix(y_test, y_pred_log)
416
+ fig2, ax2 = plt.subplots()
417
+ sns.heatmap(cm_log, annot=True, fmt='d', cmap='Blues', ax=ax2)
418
+ ax2.set_title("Logistic Regression Confusion Matrix")
419
+ st.pyplot(fig2)
420
+ plt.close(fig2)
421
+
422
+ st.session_state.trained_models = st.session_state.get("trained_models", {})
423
+ st.session_state.trained_models["Logistic Regression"] = {
424
+ "model": logreg,
425
+ "features": X.columns.tolist(),
426
+ "target": target_variable,
427
+ "predictions": y_pred_log,
428
+ "y_test": y_test,
429
+ "problem_type": "Classification"
430
+ }
431
+
432
+
433
+
434
+ # elif selected_page == "πŸ“‹ MLflow Tracking":
435
+ # st.header("πŸ“‹ MLflow Experiment Tracking")
436
+
437
+ # # --- MLflow config section ---
438
+ # st.subheader("βš™οΈ MLflow Configuration")
439
+ # tracking_uri = st.text_input("πŸ”— Tracking URI", value="http://localhost:5000")
440
+ # experiment_name = st.text_input("πŸ§ͺ Experiment Name", value="my_local_experiment")
441
+
442
+ # if st.button("πŸ”§ Set MLflow Configuration"):
443
+ # try:
444
+ # mlflow.set_tracking_uri(tracking_uri)
445
+ # mlflow.set_experiment(experiment_name)
446
+ # st.success("βœ… MLflow configured successfully!")
447
+ # except Exception as e:
448
+ # st.error(f"❌ Failed to set MLflow config: {str(e)}")
449
+
450
+ # # --- Log trained model ---
451
+ # st.subheader("πŸ“€ Log Trained Model to MLflow")
452
+
453
+ # if st.session_state.get("trained_models"):
454
+ # model_name = st.selectbox("Select a model to log:", list(st.session_state.trained_models.keys()))
455
+ # if st.button("πŸ“₯ Log This Model"):
456
+ # model_data = st.session_state.trained_models[model_name]
457
+ # try:
458
+ # with mlflow.start_run(run_name=f"{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"):
459
+ # # Log model
460
+ # mlflow.sklearn.log_model(model_data["model"], "model")
461
+
462
+ # # Log params
463
+ # mlflow.log_param("model_type", model_name)
464
+ # mlflow.log_param("target", model_data["target"])
465
+ # mlflow.log_param("features", len(model_data["features"]))
466
+
467
+ # # Log metrics
468
+ # y_test = model_data["y_test"]
469
+ # y_pred = model_data["predictions"]
470
+ # if model_data["problem_type"] == "Classification":
471
+ # acc = accuracy_score(y_test, y_pred)
472
+ # mlflow.log_metric("accuracy", acc)
473
+ # else:
474
+ # mlflow.log_metric("mse", mean_squared_error(y_test, y_pred))
475
+ # mlflow.log_metric("r2", r2_score(y_test, y_pred))
476
+ # mlflow.log_metric("mae", mean_absolute_error(y_test, y_pred))
477
+
478
+ # st.success("βœ… Model logged to MLflow!")
479
+ # except Exception as e:
480
+ # st.error(f"❌ Error logging model: {str(e)}")
481
+ # else:
482
+ # st.info("No models found. Train some models first!")
483
+
484
+ # # --- View past runs ---
485
+ # st.subheader("πŸ“ˆ Recent Experiment Runs")
486
+
487
+ # if st.button("πŸ”„ Refresh Runs"):
488
+ # try:
489
+ # runs_df = mlflow.search_runs(order_by=["start_time desc"])
490
+ # if not runs_df.empty:
491
+ # st.dataframe(
492
+ # runs_df[[
493
+ # 'run_id',
494
+ # 'status',
495
+ # 'start_time',
496
+ # 'params.model_type',
497
+ # 'params.target',
498
+ # 'metrics.accuracy', # This will show NaN for regression
499
+ # 'metrics.mse',
500
+ # 'metrics.r2'
501
+ # ]],
502
+ # use_container_width=True
503
+ # )
504
+ # else:
505
+ # st.info("πŸ“Š No runs found.")
506
+ # except Exception as e:
507
+ # st.error(f"❌ Error fetching runs: {str(e)}")
508
+
509
+
510
+
511
+ # ================== SIDEBAR ! ==================
512
+
513
+ # Help section
514
# Static help list shown under the navigation selector; it mirrors the page
# order of the selector but is plain text, not a set of links.
_SIDEBAR_HELP = """
1. 🏠 Home
2. πŸ“Š Data Viz
3. πŸ€– Logistical Regression
4. 🌳 Decision Tree
5. Model Comparison

"""

with st.sidebar:
    st.markdown("---")
    st.subheader("Where to go...")
    st.markdown(_SIDEBAR_HELP)
524
 
525
+ #6. πŸ“‹ MLflow Tracking