shauryaDugar committed on
Commit fb3f4c9 · verified · 1 Parent(s): 9bc106f

Upload 2 files

Files changed (2)
  1. requirements.txt +58 -0
  2. website_new.py +329 -0
requirements.txt ADDED
@@ -0,0 +1,58 @@
altair==5.5.0
attrs==24.3.0
beautifulsoup4==4.12.3
blinker==1.9.0
cachetools==5.5.0
certifi==2024.12.14
charset-normalizer==3.4.1
click==8.1.8
filelock==3.16.1
fsspec==2024.12.0
gdown==4.6.0
gitdb==4.0.11
GitPython==3.1.43
idna==3.10
Jinja2==3.1.5
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
narwhals==1.19.1
networkx==3.2.1
numpy==2.0.2
packaging==24.2
pandas==2.2.3
pillow==11.0.0
plotly==5.24.1
protobuf==5.29.2
pyarrow==18.1.0
pydeck==0.9.1
Pygments==2.18.0
PySocks==1.7.1
python-dateutil==2.9.0.post0
pytz==2024.2
referencing==0.35.1
requests==2.32.3
rich==13.9.4
rpds-py==0.22.3
scikit-learn==1.6.0
scipy==1.13.1
six==1.17.0
smmap==5.0.1
soupsieve==2.6
streamlit==1.41.1
sympy==1.13.1
tenacity==9.0.0
threadpoolctl==3.5.0
toml==0.10.2
torch==2.5.1
torchaudio==2.5.1
torchvision==0.20.1
tornado==6.4.2
tqdm==4.67.1
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.3.0
website_new.py ADDED
@@ -0,0 +1,329 @@
import streamlit as st
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import plotly.express as px
import plotly.graph_objects as go
import joblib
import os
import gdown
import tempfile
import shutil
import requests
import zipfile
from tqdm import tqdm

# Set page config
st.set_page_config(
    page_title="Microbiome Symptom Predictor",
    page_icon="🦠",
    layout="wide"
)

class MicrobiomeNet(nn.Module):
    def __init__(self, input_size=1024, hidden_size=128, output_size=2):
        super(MicrobiomeNet, self).__init__()

        # Feature attention network
        self.feature_attention = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

        # Abundance processing network
        self.abundance_network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, hidden_size)
        )

        # Interaction processing network
        self.interaction_network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, hidden_size)
        )

        # Final layers
        self.final_layers = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        # Apply feature attention
        attention = torch.sigmoid(self.feature_attention(x))
        x_attended = x * attention

        # Process through parallel networks
        abundance_features = self.abundance_network(x_attended)
        interaction_features = self.interaction_network(x)

        # Combine features
        combined = torch.cat([abundance_features, interaction_features], dim=1)

        # Final processing
        output = self.final_layers(combined)
        return output

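# Shape walk-through (illustrative only, assuming the default sizes above):
# an input batch x of shape (N, 1024) yields a per-sample attention scalar of
# shape (N, 1) that is broadcast across all features, two parallel 128-dim
# embeddings, a concatenated (N, 256) vector, and finally (N, 2) raw logits.
#
#     model = MicrobiomeNet()
#     model.eval()                                 # eval mode so BatchNorm accepts N == 1
#     with torch.no_grad():
#         logits = model(torch.randn(1, 1024))     # -> tensor of shape (1, 2)
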
def download_models_from_gdrive(file_id="1--s3u-BiIeoluB_ji97YE5cH13Se3dum", dest_dir="saved_models"):
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, "models.zip")
    # If zip already exists and passes a basic check, skip download
    if os.path.exists(zip_path) and os.path.getsize(zip_path) > 100:
        st.info("Model zip file already exists; skipping download.")
    else:
        st.info("Downloading models from Google Drive...")
        url = f"https://drive.google.com/u/0/uc?id={file_id}&export=download&confirm=t"
        output = gdown.download(url, zip_path, quiet=False, fuzzy=True)
        if output is None:
            raise Exception("Download failed - gdown returned None")
        st.write(f"Downloaded file size: {os.path.getsize(zip_path) / (1024*1024):.2f} MB")
    # Extract only if necessary
    extracted_dir = os.path.join(dest_dir, "extracted")
    if not os.path.exists(extracted_dir):
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extracted_dir)
        st.write("Files extracted successfully")
    return extracted_dir

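# Assumption: the downloaded zip unpacks to a top-level "saved_models" folder,
# since load_saved_models() below reads from os.path.join(extracted_dir, "saved_models").
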
@st.cache_resource
def load_saved_models():
    """Load all saved models from Google Drive"""
    models = {}
    scalers = {}
    pcas = {}

    # Download models to temporary directory
    temp_dir = download_models_from_gdrive()
    if not temp_dir:
        raise Exception("Failed to download models from Google Drive")

    try:
        # Load models from temporary directory
        models_dir = os.path.join(temp_dir, "saved_models")

        for filename in os.listdir(models_dir):
            if filename.endswith("_model.pth"):
                # Extract symptom name and handle special characters
                symptom = filename.replace("_model.pth", "")
                model_path = os.path.join(models_dir, filename)
                scaler_path = os.path.join(models_dir, f"{symptom}_scaler.joblib")
                pca_path = os.path.join(models_dir, f"{symptom}_pca.joblib")

                # Initialize and load model
                model = MicrobiomeNet(input_size=1024, hidden_size=128, output_size=2)
                model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
                model.eval()

                # Load scaler and PCA
                scaler = joblib.load(scaler_path)
                pca = joblib.load(pca_path)

                models[symptom] = model
                scalers[symptom] = scaler
                pcas[symptom] = pca

        st.write(f"Loaded {len(models)} models successfully")
        return models, scalers, pcas

    except Exception as e:
        st.error(f"Error in load_saved_models: {str(e)}")
        raise
    # finally:
    #     # Clean up temporary directory
    #     shutil.rmtree(temp_dir)

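# Expected naming convention inside saved_models/ (inferred from the loop above):
# each symptom contributes three files, <symptom>_model.pth, <symptom>_scaler.joblib,
# and <symptom>_pca.joblib; all three must be present for that symptom to load.
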
def process_species_data(file):
    """Process the uploaded species TSV file"""
    df = pd.read_csv(file, sep="\t")

    # Debug output: inspect the parsed columns and first rows
    print(df.columns)
    print("\n\n")
    print(df.head())
    print("\n\n")

    # Extract abundance and species columns
    abundance_data = df[['%_Abundance', 'Species_Name']]

    # Pivot the data to get species as columns
    pivoted_data = abundance_data.pivot_table(
        index=None,
        values='%_Abundance',
        columns='Species_Name',
        aggfunc='sum'
    ).fillna(0)

    return pivoted_data

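# Illustrative input layout (hypothetical rows); only the %_Abundance and
# Species_Name columns are used, Tax_ID and Taxonomy are optional and ignored:
#
#     %_Abundance	Species_Name
#     12.5	Bacteroides fragilis
#     0.8	Escherichia coli
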
def predict_symptoms(data, models, scalers, pcas):
    """Make predictions for each symptom"""
    predictions = {}

    for symptom, model in models.items():
        try:
            # Get the feature names from the scaler
            scaler_features = scalers[symptom].get_feature_names_out()

            # Create a DataFrame with zeros for all scaler features
            prediction_data = pd.DataFrame(0, index=[0], columns=scaler_features)

            # Fill in the available species data (copy positionally via .values so the
            # single pivoted row is not aligned against a mismatched index)
            common_species = data.columns.intersection(scaler_features)
            prediction_data[common_species] = data[common_species].values

            # Scale the data
            scaled_data = scalers[symptom].transform(prediction_data)

            # Apply PCA transformation
            pca_data = pcas[symptom].transform(scaled_data)

            # Convert to tensor
            input_tensor = torch.FloatTensor(pca_data)

            # Make prediction
            with torch.no_grad():
                output = model(input_tensor)
                prediction = torch.sigmoid(output).numpy()

            predictions[symptom] = prediction[0][0]

        except Exception as e:
            st.error(f"Error predicting {symptom}: {str(e)}")
            continue

    return predictions

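# Pipeline summary (as implemented above): a zero-filled row in the scaler's feature
# order -> scaler transform -> PCA down to the 1024 inputs MicrobiomeNet expects ->
# sigmoid over the (1, 2) output; element [0][0] is reported as that symptom's probability.
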
def get_friendly_symptom_name(symptom):
    """Convert the long symptom names to friendly display names"""
    # Dictionary mapping original names to friendly names
    name_mapping = {
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Bloating": "Bloating Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Acidity_Burning": "Acidity Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Constipation": "Constipation Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Loose_Motion_Diarrhea": "Diarrhea Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Flatulence_Gas_Fart": "Gas Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Burping": "Burping Severity",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Acidity": "Acidity Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Bloating": "Bloating Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Burping": "Burping Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Constipation": "Constipation Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Flatulence_Gas_Fart": "Gas Frequency"
    }
    return name_mapping.get(symptom, symptom)

def main():
    st.title("🦠 Microbiome Symptom Predictor")

    # Load saved models
    try:
        models, scalers, pcas = load_saved_models()
        st.success("Models loaded successfully!")

        # Display some model info
        sample_scaler = next(iter(scalers.values()))
        n_features = len(sample_scaler.get_feature_names_out())
        st.info(f"Models expect {n_features} species features and will use PCA to reduce to 1024 dimensions.")

    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return

    # File upload
    st.header("Upload Species Data")
    uploaded_file = st.file_uploader(
        "Upload your species abundance TSV file",
        type=['tsv'],
        help="Upload a TSV file containing species abundance data"
    )

    if uploaded_file is not None:
        try:
            # Process the uploaded file
            species_data = process_species_data(uploaded_file)

            # Show some data info
            st.info(f"Processed {len(species_data.columns)} species from your data.")

            # Make predictions
            predictions = predict_symptoms(species_data, models, scalers, pcas)

            if predictions:
                # Display results
                st.header("Prediction Results")

                # Create two columns
                col1, col2 = st.columns(2)

                with col1:
                    st.subheader("Prediction Scores")
                    # Create a DataFrame for the predictions with friendly names
                    pred_df = pd.DataFrame({
                        'Symptom': [get_friendly_symptom_name(k) for k in predictions.keys()],
                        'Probability': list(predictions.values())
                    })

                    # Display as table
                    st.dataframe(pred_df.style.format({'Probability': '{:.2%}'}))

                with col2:
                    st.subheader("Visualization")
                    # Create bar plot with friendly names
                    fig = go.Figure(data=[
                        go.Bar(
                            x=[get_friendly_symptom_name(k) for k in predictions.keys()],
                            y=list(predictions.values()),
                            text=[f"{v:.1%}" for v in predictions.values()],
                            textposition='auto',
                        )
                    ])

                    fig.update_layout(
                        title="Symptom Prediction Probabilities",
                        xaxis_title="Symptoms",
                        yaxis_title="Probability",
                        yaxis_range=[0, 1],
                        template="plotly_white",
                        paper_bgcolor='rgba(0,0,0,0)'
                    )

                    # Rotate x-axis labels for better readability
                    fig.update_layout(
                        xaxis_tickangle=-45,
                        margin=dict(b=100)  # Add bottom margin for rotated labels
                    )

                    st.plotly_chart(fig, use_container_width=True)

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.write("Error details:", str(e))
            st.write("Please ensure your TSV file:")
            st.write("1. Contains '%_Abundance' and 'Species_Name' columns")
            st.write("2. Is properly formatted")
            st.write("3. Contains species that match the training data")

    # Add information about the expected format
    with st.expander("ℹ️ Input Format Information"):
        st.write("""
        Your TSV file should contain the following columns:
        - %_Abundance: Numerical values representing species abundance
        - Species_Name: Names of the species
        - Tax_ID: Taxonomy IDs (optional)
        - Taxonomy: Full taxonomy information (optional)

        Only the abundance and species name columns will be used for prediction.
        """)

if __name__ == "__main__":
    main()