Spaces:
Runtime error
Runtime error
| import altair as alt | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import json | |
| import random | |
| import pickle | |
| # set page configuration to wide mode | |
| st.set_page_config(layout="wide") | |
| # section 1 | |
| st.markdown("#### About") | |
| st.markdown("Testing") | |
| st.markdown("Links: \n* GitHub: [https://github.com/vkola-lab/ncomms2025](https://github.com/vkola-lab/ncomms2025)\n* Our lab: [https://vkola-lab.github.io/](https://vkola-lab.github.io/)") | |
| # section 2 | |
| st.markdown("#### Demo") | |
| st.markdown("This Hugging Face Space is published for demonstration purposes. Users can input over 300 clinical entries to assess probabilities of amyloid and tau PET positivity. However, due to the computational power limitations of the Hugging Face free tier, imaging features and Shapley values analysis are not supported. For the full implementation, please refer to our GitHub repository.") | |
| st.markdown("To use the demo:\n* Provide input features in the form below. Feature missing is allowed.\n* Click the \"**RANDOM EXAMPLE**\" button to populate the form with a randomly selected datapoint.\n* Use the \"**PREDICT**\" button to submit all input features for assessment, then the predictions will be posted in a table.") | |
| # load model | |
| def load_model(): | |
| import adrd | |
| ckpt_path = 'src/model_stage_1.ckpt' | |
| model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu') | |
| return model | |
| model = load_model() | |
| def predict_proba(data_dict): | |
| pred_dict = model.predict_proba([data_dict])[1][0] | |
| return pred_dict | |
| # load meta data csv | |
| file_path = 'src/data/input_meta_info.csv' | |
| input_meta_info = pd.read_csv(file_path) | |
| # load NACC testing data | |
| from data.dataset_csv_new import CSVDataset | |
| dat_tst = CSVDataset( | |
| dat_file = "src/data/test_public.csv", | |
| cnf_file = "src/data/input_meta_info.csv" | |
| ) | |
| def get_random_example(): | |
| idx = random.randint(0, len(dat_tst) - 1) | |
| random_case = dat_tst[idx][0] | |
| return random_case | |
| # Get random example features if the button is clicked | |
| if 'random_example' not in st.session_state: | |
| st.session_state.random_example = None | |
| st.markdown('---') | |
| cols = st.columns(3) | |
| with cols[1]: | |
| random_example_button = st.button("RANDOM EXAMPLE", use_container_width=True) | |
| if random_example_button: | |
| st.session_state.random_example = get_random_example() | |
| st.rerun() | |
| random_example = st.session_state.random_example | |
| def create_input(df, i): | |
| row = df.iloc[i] | |
| name = row['Name'] | |
| description = row['Description'] | |
| # dirty work, inspect keys and values | |
| values = row['Values'] | |
| values = values.replace('\'', '\"') | |
| values = values.replace('\"0\": nan, ', '') | |
| values = json.loads(values) | |
| for k, v in list(values.items()): | |
| if v == 'Unknown': | |
| values.pop(k) | |
| elif k in ('9', '99', '999'): | |
| values.pop(k) | |
| # get default value from random example if available | |
| default_value = random_example[name] if random_example and name in random_example else None | |
| if type(default_value) is float: | |
| default_value = int(default_value) | |
| # Determine the type of widget based on values | |
| if 'range' in values: | |
| if ' - ' in values['range']: | |
| min_value, max_value = map(float, values['range'].split(' - ')) | |
| min_value, max_value = int(min_value), int(max_value) | |
| if default_value is not None: | |
| if default_value > max_value or default_value < min_value: | |
| default_value = None | |
| st.number_input(description, key=name, min_value=min_value, max_value=max_value, value=default_value, placeholder=values['range']) | |
| else: | |
| min_value = int(values['range'].replace('>= ', '')) | |
| if default_value is not None: | |
| if default_value < min_value or default_value == 8888: | |
| default_value = None | |
| st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range']) | |
| else: | |
| values = {int(k): v for k, v in values.items()} | |
| if default_value in values: | |
| default_index = list(values.keys()).index(default_value) | |
| else: | |
| default_index = None | |
| st.selectbox( | |
| description, | |
| options = values.keys(), | |
| key = name, | |
| index = default_index, | |
| format_func=lambda x: values[x] | |
| ) | |
| # create form | |
| with st.form("dynamic_form"): | |
| sections = input_meta_info['Section'].unique() | |
| for section in sections: | |
| with st.container(): | |
| st.markdown(f"##### {section}") | |
| sub_df = input_meta_info[input_meta_info['Section'] == section] | |
| cols = st.columns(3) | |
| with cols[0]: | |
| for i in range(0, len(sub_df), 3): | |
| create_input(sub_df, i) | |
| with cols[1]: | |
| for i in range(1, len(sub_df), 3): | |
| create_input(sub_df, i) | |
| with cols[2]: | |
| for i in range(2, len(sub_df), 3): | |
| create_input(sub_df, i) | |
| # seperate line | |
| st.markdown("---") | |
| cols = st.columns(3) | |
| with cols[1]: | |
| predict_button = st.form_submit_button("PREDICT", use_container_width=True, type='primary') | |
| # load mapping | |
| with open('src/data/nacc_variable_mapping.pkl', 'rb') as file: | |
| nacc_mapping = pickle.load(file) | |
| def convert_dictionary(original_dict, mappings): | |
| transformed_dict = {} | |
| for key, value in original_dict.items(): | |
| if key in mappings: | |
| new_key, transform_map = mappings[key] | |
| # If the value needs to be transformed | |
| if value is not None and value in transform_map: | |
| transformed_value = transform_map[value] | |
| else: | |
| transformed_value = value # Keep the original value if no transformation is needed | |
| transformed_dict[new_key] = transformed_value | |
| return transformed_dict | |
| if predict_button: | |
| # get form input | |
| names = input_meta_info['Name'].tolist() | |
| data_dict = {} | |
| for name in names: | |
| data_dict[name] = st.session_state[name] | |
| # convert | |
| data_dict = convert_dictionary(data_dict, nacc_mapping) | |
| pred_dict = predict_proba(data_dict) | |
| # change key name and value representations | |
| key_mappings = { | |
| 'amy_label': 'Amyloid PET', | |
| 'tau_label': 'Tau PET', | |
| } | |
| pred_dict = {key_mappings[k]: f"{v * 100:.2f}%" for k, v in pred_dict.items()} | |
| df = pd.DataFrame(list(pred_dict.items()), columns=['Label', 'Predicted probability']) | |
| st.table(df) |