Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,24 +1,22 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
|
|
|
|
|
|
|
|
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
from typing import List, Optional, Dict
|
| 5 |
from PIL import Image
|
| 6 |
|
| 7 |
import streamlit as st
|
|
|
|
|
|
|
| 8 |
from markdown import markdown
|
| 9 |
st.set_page_config(layout="wide", page_title="Annotation of Simulated Patients", initial_sidebar_state="collapsed")
|
| 10 |
-
|
| 11 |
import re, textwrap, html as py_html
|
| 12 |
from pathlib import Path
|
| 13 |
-
from streamlit.components.v1 import html as st_html
|
| 14 |
-
|
| 15 |
-
import numpy as np
|
| 16 |
-
import pandas as pd
|
| 17 |
from fsspec.implementations.local import LocalFileSystem
|
| 18 |
from huggingface_hub import HfFileSystem
|
| 19 |
|
| 20 |
-
import streamlit.components.v1 as components
|
| 21 |
-
|
| 22 |
@dataclass
|
| 23 |
class Field:
|
| 24 |
type: str
|
|
@@ -119,7 +117,7 @@ Please indicate, in the box below, that you are at least 18 years old, have read
|
|
| 119 |
- *I agree to participate in this research and I want to continue with the survey.*
|
| 120 |
'''
|
| 121 |
|
| 122 |
-
|
| 123 |
Field(name="patient", type="input_col", title=" "),
|
| 124 |
Field(type="expander", title="**Session Transcription:** *(expand)*", children=[
|
| 125 |
Field(name="dialogue_name", type="input_col", title=""),
|
|
@@ -140,7 +138,8 @@ fields: List[Field] = [
|
|
| 140 |
Field(name="rupture_marker", type="rupture_markers",
|
| 141 |
title="Select rupture markers noted in the session, include line numbers where rupture is found.", mandatory=False),
|
| 142 |
]),
|
| 143 |
-
|
|
|
|
| 144 |
Field(type="container", title="#### True-To-Patient-Prompt Features", children=[
|
| 145 |
Field(type="expander", title="**Patient Role Description:** *(expand)*", children=[
|
| 146 |
Field(name="role_name", type="input_col", title=""),
|
|
@@ -198,6 +197,9 @@ fields: List[Field] = [
|
|
| 198 |
Field(name="other_comments", type="text", title="Please provide any additional details or information:", mandatory=False),
|
| 199 |
]),
|
| 200 |
]
|
|
|
|
|
|
|
|
|
|
| 201 |
url_conditional_fields = [
|
| 202 |
Field(name="skip", type="skip_checkbox",
|
| 203 |
title="*I am uncomfortable annotating this text and voluntarily skip this instance*", mandatory=False)
|
|
@@ -214,6 +216,8 @@ INPUT_FIELD_DEFAULT_VALUES = {'slider': 0,
|
|
| 214 |
SHOW_HELP_ICON = False
|
| 215 |
SHOW_VALIDATION_ERROR_MESSAGE = True
|
| 216 |
|
|
|
|
|
|
|
| 217 |
########################################################################################
|
| 218 |
if filesystem == 'hf':
|
| 219 |
HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
|
|
@@ -535,6 +539,7 @@ def validate_current_page(fields: List[Field], index: int) -> bool:
|
|
| 535 |
|
| 536 |
# Function to navigate rows
|
| 537 |
def navigate(index_change):
|
|
|
|
| 538 |
st.session_state.current_index += index_change
|
| 539 |
# only works consistently if done before rerun
|
| 540 |
js = '''
|
|
@@ -640,8 +645,6 @@ def show_field(f: Field, index: int, data_collected):
|
|
| 640 |
if data_collected else INPUT_FIELD_DEFAULT_VALUES[f.type]
|
| 641 |
)
|
| 642 |
value = _ensure_key(key, default_val)
|
| 643 |
-
#value = st.session_state[key] if key in st.session_state else \
|
| 644 |
-
# (data_collected[f.name] if data_collected else INPUT_FIELD_DEFAULT_VALUES[f.type])
|
| 645 |
if not SHOW_HELP_ICON:
|
| 646 |
f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
|
| 647 |
|
|
@@ -765,14 +768,33 @@ def show_fields(fields: List[Field]):
|
|
| 765 |
# mark that the page has been rendered at least once (if you still use this elsewhere)
|
| 766 |
st.session_state.form_displayed = st.session_state.current_index
|
| 767 |
|
| 768 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
payload = {
|
| 770 |
'user_id': st.session_state.user_id,
|
| 771 |
'index': st.session_state.current_index,
|
| 772 |
**(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if 0 <= index < len(st.session_state.data) else {}),
|
| 773 |
**{k: st.session_state[k + str(index)] for k in st.session_state.data_inputs_keys},
|
| 774 |
-
'skip': skip_sample
|
|
|
|
| 775 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
# normalize rupture markers -> always write 8 slots
|
| 777 |
count_key = f"rupture_count_{index}"
|
| 778 |
count = st.session_state.get(count_key, 1)
|
|
@@ -868,10 +890,14 @@ if 'current_index' not in st.session_state:
|
|
| 868 |
st.session_state.current_index = start_index+1
|
| 869 |
st.session_state.form_displayed = -2
|
| 870 |
|
|
|
|
|
|
|
|
|
|
| 871 |
if get_param_from_url('show_extra_fields'):
|
| 872 |
-
|
| 873 |
else:
|
| 874 |
-
|
|
|
|
| 875 |
|
| 876 |
def add_validated_submit(fields, message):
|
| 877 |
st.session_state.form_displayed = st.session_state.current_index
|
|
@@ -907,17 +933,29 @@ elif st.session_state.current_index == -1:
|
|
| 907 |
add_validated_submit([st.session_state.user_id], "Please enter a valid user ID")
|
| 908 |
|
| 909 |
elif st.session_state.current_index < len(st.session_state.data):
|
| 910 |
-
|
|
|
|
|
|
|
|
|
|
| 911 |
|
| 912 |
# Action buttons
|
| 913 |
c1, c2, c3 = st.columns([1,1,6])
|
| 914 |
with c1:
|
| 915 |
-
if
|
| 916 |
-
|
|
|
|
|
|
|
|
|
|
| 917 |
with st.spinner("saving"):
|
| 918 |
-
prep_and_save_data(st.session_state.current_index,
|
| 919 |
-
|
| 920 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 921 |
|
| 922 |
elif st.session_state.current_index == len(st.session_state.data):
|
| 923 |
st.write(f"**Thank you for taking part in this study!** \n ")
|
|
@@ -927,10 +965,16 @@ elif st.session_state.current_index == len(st.session_state.data):
|
|
| 927 |
# Navigation buttons
|
| 928 |
if 0 < st.session_state.current_index < len(st.session_state.data):
|
| 929 |
if st.button("Previous"):
|
| 930 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 931 |
|
| 932 |
if 0 <= st.session_state.current_index < len(st.session_state.data):
|
| 933 |
-
st.write(f"
|
| 934 |
|
| 935 |
st.markdown(
|
| 936 |
"""<style>
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
from dataclasses import dataclass, field
|
| 7 |
from typing import List, Optional, Dict
|
| 8 |
from PIL import Image
|
| 9 |
|
| 10 |
import streamlit as st
|
| 11 |
+
import streamlit.components.v1 as components
|
| 12 |
+
from streamlit.components.v1 import html as st_html
|
| 13 |
from markdown import markdown
|
| 14 |
st.set_page_config(layout="wide", page_title="Annotation of Simulated Patients", initial_sidebar_state="collapsed")
|
|
|
|
| 15 |
import re, textwrap, html as py_html
|
| 16 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from fsspec.implementations.local import LocalFileSystem
|
| 18 |
from huggingface_hub import HfFileSystem
|
| 19 |
|
|
|
|
|
|
|
| 20 |
@dataclass
|
| 21 |
class Field:
|
| 22 |
type: str
|
|
|
|
| 117 |
- *I agree to participate in this research and I want to continue with the survey.*
|
| 118 |
'''
|
| 119 |
|
| 120 |
+
fields0: List[Field] = [
|
| 121 |
Field(name="patient", type="input_col", title=" "),
|
| 122 |
Field(type="expander", title="**Session Transcription:** *(expand)*", children=[
|
| 123 |
Field(name="dialogue_name", type="input_col", title=""),
|
|
|
|
| 138 |
Field(name="rupture_marker", type="rupture_markers",
|
| 139 |
title="Select rupture markers noted in the session, include line numbers where rupture is found.", mandatory=False),
|
| 140 |
]),
|
| 141 |
+
]
|
| 142 |
+
fields1: List[Field] = [
|
| 143 |
Field(type="container", title="#### True-To-Patient-Prompt Features", children=[
|
| 144 |
Field(type="expander", title="**Patient Role Description:** *(expand)*", children=[
|
| 145 |
Field(name="role_name", type="input_col", title=""),
|
|
|
|
| 197 |
Field(name="other_comments", type="text", title="Please provide any additional details or information:", mandatory=False),
|
| 198 |
]),
|
| 199 |
]
|
| 200 |
+
|
| 201 |
+
STEPS: List[List[Field]] = [fields0, fields1]
|
| 202 |
+
|
| 203 |
url_conditional_fields = [
|
| 204 |
Field(name="skip", type="skip_checkbox",
|
| 205 |
title="*I am uncomfortable annotating this text and voluntarily skip this instance*", mandatory=False)
|
|
|
|
| 216 |
SHOW_HELP_ICON = False
|
| 217 |
SHOW_VALIDATION_ERROR_MESSAGE = True
|
| 218 |
|
| 219 |
+
|
| 220 |
+
|
| 221 |
########################################################################################
|
| 222 |
if filesystem == 'hf':
|
| 223 |
HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
|
|
|
|
| 539 |
|
| 540 |
# Function to navigate rows
|
| 541 |
def navigate(index_change):
|
| 542 |
+
st.session_state.step = 0
|
| 543 |
st.session_state.current_index += index_change
|
| 544 |
# only works consistently if done before rerun
|
| 545 |
js = '''
|
|
|
|
| 645 |
if data_collected else INPUT_FIELD_DEFAULT_VALUES[f.type]
|
| 646 |
)
|
| 647 |
value = _ensure_key(key, default_val)
|
|
|
|
|
|
|
| 648 |
if not SHOW_HELP_ICON:
|
| 649 |
f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
|
| 650 |
|
|
|
|
| 768 |
# mark that the page has been rendered at least once (if you still use this elsewhere)
|
| 769 |
st.session_state.form_displayed = st.session_state.current_index
|
| 770 |
|
| 771 |
+
def iter_all_input_fields():
|
| 772 |
+
"""Yield all Field objects which are real input widgets, across all steps."""
|
| 773 |
+
def walk(nodes):
|
| 774 |
+
for f in nodes:
|
| 775 |
+
if f.children:
|
| 776 |
+
walk(f.children)
|
| 777 |
+
elif f.type in INPUT_FIELD_DEFAULT_VALUES:
|
| 778 |
+
yield f
|
| 779 |
+
for step_fields in STEPS:
|
| 780 |
+
yield from walk(step_fields)
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def prep_and_save_data(index, skip_sample, completed: bool):
|
| 784 |
payload = {
|
| 785 |
'user_id': st.session_state.user_id,
|
| 786 |
'index': st.session_state.current_index,
|
| 787 |
**(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if 0 <= index < len(st.session_state.data) else {}),
|
| 788 |
**{k: st.session_state[k + str(index)] for k in st.session_state.data_inputs_keys},
|
| 789 |
+
'skip': skip_sample,
|
| 790 |
+
'completed': completed,
|
| 791 |
}
|
| 792 |
+
|
| 793 |
+
for f in iter_all_input_fields():
|
| 794 |
+
key = f.name + str(index)
|
| 795 |
+
val = st.session_state.get(key, INPUT_FIELD_DEFAULT_VALUES[f.type])
|
| 796 |
+
payload[f.name] = val
|
| 797 |
+
|
| 798 |
# normalize rupture markers -> always write 8 slots
|
| 799 |
count_key = f"rupture_count_{index}"
|
| 800 |
count = st.session_state.get(count_key, 1)
|
|
|
|
| 890 |
st.session_state.current_index = start_index+1
|
| 891 |
st.session_state.form_displayed = -2
|
| 892 |
|
| 893 |
+
if 'step' not in st.session_state:
|
| 894 |
+
st.session_state.step = 0
|
| 895 |
+
|
| 896 |
if get_param_from_url('show_extra_fields'):
|
| 897 |
+
fields1 += url_conditional_fields
|
| 898 |
else:
|
| 899 |
+
fields1 += url_conditional_fields
|
| 900 |
+
|
| 901 |
|
| 902 |
def add_validated_submit(fields, message):
|
| 903 |
st.session_state.form_displayed = st.session_state.current_index
|
|
|
|
| 933 |
add_validated_submit([st.session_state.user_id], "Please enter a valid user ID")
|
| 934 |
|
| 935 |
elif st.session_state.current_index < len(st.session_state.data):
|
| 936 |
+
step = st.session_state.step
|
| 937 |
+
total_steps = len(STEPS)
|
| 938 |
+
|
| 939 |
+
show_fields(STEPS[step])
|
| 940 |
|
| 941 |
# Action buttons
|
| 942 |
c1, c2, c3 = st.columns([1,1,6])
|
| 943 |
with c1:
|
| 944 |
+
label = "Next" if step < total_steps - 1 else "Submit & next session"
|
| 945 |
+
if st.button(label):
|
| 946 |
+
if validate_current_page(STEPS[step], st.session_state.current_index):
|
| 947 |
+
is_last_page = (step == total_steps - 1)
|
| 948 |
+
|
| 949 |
with st.spinner("saving"):
|
| 950 |
+
prep_and_save_data(st.session_state.current_index,
|
| 951 |
+
('skip' in st.session_state and st.session_state['skip']),
|
| 952 |
+
completed=is_last_page)
|
| 953 |
+
if is_last_page:
|
| 954 |
+
st.success("Saved!")
|
| 955 |
+
navigate(1)
|
| 956 |
+
else:
|
| 957 |
+
st.session_state.step += 1
|
| 958 |
+
st.rerun()
|
| 959 |
|
| 960 |
elif st.session_state.current_index == len(st.session_state.data):
|
| 961 |
st.write(f"**Thank you for taking part in this study!** \n ")
|
|
|
|
| 965 |
# Navigation buttons
|
| 966 |
if 0 < st.session_state.current_index < len(st.session_state.data):
|
| 967 |
if st.button("Previous"):
|
| 968 |
+
if step > 0:
|
| 969 |
+
st.session_state.step -= 1
|
| 970 |
+
st.rerun()
|
| 971 |
+
else:
|
| 972 |
+
st.session_state.current_index -= 1
|
| 973 |
+
st.session_state.step = total_steps - 1
|
| 974 |
+
st.rerun()
|
| 975 |
|
| 976 |
if 0 <= st.session_state.current_index < len(st.session_state.data):
|
| 977 |
+
st.write(f"Session {st.session_state.current_index + 1} out of {len(st.session_state.data)}")
|
| 978 |
|
| 979 |
st.markdown(
|
| 980 |
"""<style>
|