Spaces:
Build error
Build error
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import joblib | |
| import os | |
| import gdown | |
| import tempfile | |
| import shutil | |
| import requests | |
| import zipfile | |
| from tqdm import tqdm | |
# Page-level Streamlit settings: browser-tab title, tab icon and wide layout.
_PAGE_SETTINGS = {
    "page_title": "Microbiome Symptom Predictor",
    "page_icon": "🦠",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
class MicrobiomeNet(nn.Module):
    """Two-branch feed-forward network over a fixed-width feature vector.

    A small gating head produces one sigmoid-squashed scalar per sample; that
    scalar scales the whole input row before it enters the "abundance" branch,
    while the "interaction" branch receives the unscaled input.  Both branch
    outputs are concatenated and mapped to ``output_size`` raw scores (no
    activation is applied after the last linear layer).

    Args:
        input_size: Width of each input row.
        hidden_size: Width of every hidden layer.
        output_size: Number of values produced per sample.
    """

    def __init__(self, input_size=1024, hidden_size=128, output_size=2):
        super().__init__()

        def dense_stack(n_in, n_out):
            # Linear -> ReLU -> BatchNorm -> Dropout(0.2) -> Linear; the same
            # layout is used by both branches and by the output head.
            return nn.Sequential(
                nn.Linear(n_in, hidden_size),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_size),
                nn.Dropout(0.2),
                nn.Linear(hidden_size, n_out),
            )

        # Attribute names and layer positions determine the state_dict keys,
        # so they are kept exactly as they were.
        self.feature_attention = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )
        self.abundance_network = dense_stack(input_size, hidden_size)
        self.interaction_network = dense_stack(input_size, hidden_size)
        self.final_layers = dense_stack(hidden_size * 2, output_size)

    def forward(self, x):
        """Return raw scores of shape ``(batch, output_size)`` for input ``x``."""
        # One gate value per sample, broadcast across all features.
        gate = torch.sigmoid(self.feature_attention(x))
        abundance_out = self.abundance_network(x * gate)
        interaction_out = self.interaction_network(x)
        merged = torch.cat((abundance_out, interaction_out), dim=1)
        return self.final_layers(merged)
def download_models_from_gdrive(file_id="1--s3u-BiIeoluB_ji97YE5cH13Se3dum", dest_dir="saved_models"):
    """Download the model archive from Google Drive and unpack it once.

    The archive is stored as ``<dest_dir>/models.zip`` and unpacked into
    ``<dest_dir>/extracted``.  Both steps are skipped when their result is
    already present on disk.

    Args:
        file_id: Google Drive file id of the zip archive.
        dest_dir: Directory that holds the archive and the extracted files.

    Returns:
        Path of the directory the archive was extracted into.

    Raises:
        RuntimeError: If gdown reports a failed download or the downloaded
            file is not a zip archive.
    """
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, "models.zip")

    # A size check alone would accept a truncated download or an HTML error
    # page saved under the zip name, so verify the file is a real archive.
    if os.path.exists(zip_path) and zipfile.is_zipfile(zip_path):
        st.info("Model zip file already exists; skipping download.")
    else:
        if os.path.exists(zip_path):
            # Stale or corrupt leftover from an earlier attempt.
            os.remove(zip_path)
        st.info("Downloading models from Google Drive...")
        url = f"https://drive.google.com/u/0/uc?id={file_id}&export=download&confirm=t"
        output = gdown.download(url, zip_path, quiet=False, fuzzy=True)
        if output is None:
            raise RuntimeError("Download failed - gdown returned None")
        if not zipfile.is_zipfile(zip_path):
            # Do not keep a bad file around, otherwise it would be reused.
            if os.path.exists(zip_path):
                os.remove(zip_path)
            raise RuntimeError("Download failed - downloaded file is not a valid zip archive")
        st.write(f"Downloaded file size: {os.path.getsize(zip_path) / (1024*1024):.2f} MB")

    # Extract only if necessary
    extracted_dir = os.path.join(dest_dir, "extracted")
    if not os.path.exists(extracted_dir):
        # Unpack into a staging directory and rename on success, so an
        # interrupted extraction is never mistaken for a complete one.
        staging_dir = extracted_dir + ".partial"
        shutil.rmtree(staging_dir, ignore_errors=True)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(staging_dir)
        os.rename(staging_dir, extracted_dir)
        st.write("Files extracted successfully")
    return extracted_dir
def load_saved_models():
    """Load every per-symptom model, scaler and PCA from the model archive.

    The archive is fetched with ``download_models_from_gdrive()``; inside it,
    a ``saved_models`` directory is scanned for ``<symptom>_model.pth`` files.
    Each one is paired with ``<symptom>_scaler.joblib`` and
    ``<symptom>_pca.joblib`` from the same directory.

    Returns:
        A ``(models, scalers, pcas)`` tuple of dicts keyed by symptom name.

    Raises:
        Exception: Any failure while downloading or loading is reported with
            ``st.error`` and re-raised.  A ``RuntimeError`` is raised when the
            archive contains no model files at all.
    """
    models = {}
    scalers = {}
    pcas = {}

    extracted_dir = download_models_from_gdrive()
    if not extracted_dir:
        raise Exception("Failed to download models from Google Drive")

    model_suffix = "_model.pth"
    try:
        models_dir = os.path.join(extracted_dir, "saved_models")
        # Sorted so the symptom order is the same on every run and platform.
        for filename in sorted(os.listdir(models_dir)):
            if not filename.endswith(model_suffix):
                continue
            # Strip only the trailing suffix to get the symptom name.
            symptom = filename[: -len(model_suffix)]
            model_path = os.path.join(models_dir, filename)
            scaler_path = os.path.join(models_dir, f"{symptom}_scaler.joblib")
            pca_path = os.path.join(models_dir, f"{symptom}_pca.joblib")

            # NOTE(review): torch.load and joblib.load both unpickle their
            # input; only use archives from a trusted source.
            model = MicrobiomeNet(input_size=1024, hidden_size=128, output_size=2)
            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            model.eval()

            models[symptom] = model
            scalers[symptom] = joblib.load(scaler_path)
            pcas[symptom] = joblib.load(pca_path)

        if not models:
            raise RuntimeError(f"No '*{model_suffix}' files found in {models_dir}")

        st.write(f"Loaded {len(models)} models successfully")
        return models, scalers, pcas
    except Exception as e:
        st.error(f"Error in load_saved_models: {str(e)}")
        raise
def process_species_data(file):
    """Turn an uploaded species TSV into a single-row abundance table.

    Args:
        file: Path or file-like object readable by ``pandas.read_csv`` with a
            tab separator.  It must contain the columns ``%_Abundance`` and
            ``Species_Name``; any other columns are ignored.

    Returns:
        A DataFrame with one column per distinct species name holding the
        summed ``%_Abundance`` of that species.  The row index is reset to the
        default integer index (``0``): ``pivot_table`` would otherwise label
        the row ``'%_Abundance'``, and index-aligned assignment into a frame
        indexed by ``0`` would then yield NaN instead of the values.

    Raises:
        KeyError: If a required column is missing.
    """
    df = pd.read_csv(file, sep="\t")

    required = ['%_Abundance', 'Species_Name']
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"Missing required column(s): {missing}")

    # Pivot so that species become columns; duplicate species rows are summed.
    pivoted_data = df[required].pivot_table(
        index=None,
        values='%_Abundance',
        columns='Species_Name',
        aggfunc='sum',
    ).fillna(0)
    return pivoted_data.reset_index(drop=True)
def predict_symptoms(data, models, scalers, pcas):
    """Run every symptom model on the first row of ``data``.

    Args:
        data: DataFrame whose columns are species names; only the first row
            is used.
        models: Dict mapping symptom name to a callable model.
        scalers: Dict mapping symptom name to a fitted scaler exposing
            ``get_feature_names_out()`` and ``transform()``.
        pcas: Dict mapping symptom name to a fitted object with ``transform()``.

    Returns:
        Dict mapping symptom name to a float score.  Symptoms whose prediction
        raised are reported with ``st.error`` and left out of the result.
    """
    predictions = {}

    if data.shape[0] == 0:
        st.error("Error predicting symptoms: the uploaded data contains no rows")
        return predictions

    # The sample row does not depend on the symptom, so take it once.
    sample = data.iloc[[0]]

    for symptom, model in models.items():
        try:
            scaler = scalers[symptom]
            scaler_features = scaler.get_feature_names_out()

            # Lay the sample out in the scaler's column order, with 0 for
            # species the sample lacks.  Values are copied positionally:
            # assigning one DataFrame into another aligns on the row index,
            # and a mismatched index silently turns every value into NaN.
            aligned = sample.reindex(columns=scaler_features, fill_value=0.0)
            prediction_data = pd.DataFrame(
                aligned.to_numpy(dtype=float),
                index=[0],
                columns=scaler_features,
            )

            scaled_data = scaler.transform(prediction_data)
            pca_data = pcas[symptom].transform(scaled_data)
            input_tensor = torch.as_tensor(pca_data, dtype=torch.float32)

            with torch.no_grad():
                output = model(input_tensor)
                # NOTE(review): sigmoid of the first output value is reported.
                # If the models were trained as 2-class softmax classifiers
                # this is not the positive-class probability - confirm against
                # the training code.
                prediction = torch.sigmoid(output).numpy()
            predictions[symptom] = float(prediction[0][0])
        except Exception as e:
            st.error(f"Error predicting {symptom}: {str(e)}")
            continue
    return predictions
def get_friendly_symptom_name(symptom):
    """Return a short display label for a questionnaire-derived symptom key.

    Keys that have no entry in the lookup table are returned unchanged.
    """
    friendly_names = {
        # Severity questions
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Bloating": "Bloating Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Acidity_Burning": "Acidity Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Constipation": "Constipation Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Loose_Motion_Diarrhea": "Diarrhea Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Flatulence_Gas_Fart": "Gas Severity",
        "How_much_does_these_symptoms_bother_your_daily_life_from_1-10?__(Please_respond_for_all_symptoms)_Burping": "Burping Severity",
        # Frequency questions
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Acidity": "Acidity Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Bloating": "Bloating Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Burping": "Burping Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Constipation": "Constipation Frequency",
        "How_many_days_in_a_week_do_you_generally_experience_the_following_symptoms?_(Please_respond_for_all_symptoms)_Flatulence_Gas_Fart": "Gas Frequency",
    }
    if symptom in friendly_names:
        return friendly_names[symptom]
    return symptom
def _render_prediction_results(predictions):
    """Show the predictions as a table and a bar chart in two columns."""
    labels = [get_friendly_symptom_name(key) for key in predictions]
    scores = list(predictions.values())

    st.header("Prediction Results")
    table_col, chart_col = st.columns(2)

    with table_col:
        st.subheader("Prediction Scores")
        pred_df = pd.DataFrame({
            'Symptom': labels,
            'Probability': scores,
        })
        st.dataframe(pred_df.style.format({'Probability': '{:.2%}'}))

    with chart_col:
        st.subheader("Visualization")
        fig = go.Figure(data=[
            go.Bar(
                x=labels,
                y=scores,
                text=[f"{v:.1%}" for v in scores],
                textposition='auto',
            )
        ])
        fig.update_layout(
            title="Symptom Prediction Probabilities",
            xaxis_title="Symptoms",
            yaxis_title="Probability",
            yaxis_range=[0, 1],
            template="plotly_white",
            paper_bgcolor='rgba(0,0,0,0)',
            # Rotated tick labels need extra room below the plot.
            xaxis_tickangle=-45,
            margin=dict(b=100),
        )
        st.plotly_chart(fig, use_container_width=True)


def main():
    """Streamlit entry point: load models, accept a TSV upload, show predictions.

    Loading failures are reported with ``st.error`` and end the run early.
    Errors while processing an uploaded file are reported together with a
    short checklist of input requirements.
    """
    st.title("🦠 Microbiome Symptom Predictor")

    # Load saved models
    try:
        models, scalers, pcas = load_saved_models()
        if not models or not scalers:
            # Without this guard the next(iter(...)) below raises
            # StopIteration, which surfaces as an empty error message after a
            # misleading success banner.
            st.error("Error loading models: no models were found")
            return
        st.success("Models loaded successfully!")
        # Display some model info
        sample_scaler = next(iter(scalers.values()))
        n_features = len(sample_scaler.get_feature_names_out())
        st.info(f"Models expect {n_features} species features and will use PCA to reduce to 1024 dimensions.")
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return

    # File upload
    st.header("Upload Species Data")
    uploaded_file = st.file_uploader(
        "Upload your species abundance TSV file",
        type=['tsv'],
        help="Upload a TSV file containing species abundance data"
    )

    if uploaded_file is not None:
        try:
            species_data = process_species_data(uploaded_file)
            st.info(f"Processed {len(species_data.columns)} species from your data.")

            predictions = predict_symptoms(species_data, models, scalers, pcas)
            if predictions:
                _render_prediction_results(predictions)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.write("Error details:", str(e))
            st.write("Please ensure your TSV file:")
            st.write("1. Contains '%_Abundance' and 'Species_Name' columns")
            st.write("2. Is properly formatted")
            st.write("3. Contains species that match the training data")

    # Add information about the expected format
    with st.expander("ℹ️ Input Format Information"):
        st.write("""
        Your TSV file should contain the following columns:
        - %_Abundance: Numerical values representing species abundance
        - Species_Name: Names of the species
        - Tax_ID: Taxonomy IDs (optional)
        - Taxonomy: Full taxonomy information (optional)
        Only the abundance and species name columns will be used for prediction.
        """)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()