emvecchi commited on
Commit
46c2410
·
verified ·
1 Parent(s): 8b54a3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -23
app.py CHANGED
@@ -1,24 +1,22 @@
1
  import json
2
  import os
 
 
 
3
  from dataclasses import dataclass, field
4
  from typing import List, Optional, Dict
5
  from PIL import Image
6
 
7
  import streamlit as st
 
 
8
  from markdown import markdown
9
  st.set_page_config(layout="wide", page_title="Annotation of Simulated Patients", initial_sidebar_state="collapsed")
10
-
11
  import re, textwrap, html as py_html
12
  from pathlib import Path
13
- from streamlit.components.v1 import html as st_html
14
-
15
- import numpy as np
16
- import pandas as pd
17
  from fsspec.implementations.local import LocalFileSystem
18
  from huggingface_hub import HfFileSystem
19
 
20
- import streamlit.components.v1 as components
21
-
22
  @dataclass
23
  class Field:
24
  type: str
@@ -119,7 +117,7 @@ Please indicate, in the box below, that you are at least 18 years old, have read
119
  - *I agree to participate in this research and I want to continue with the survey.*
120
  '''
121
 
122
- fields: List[Field] = [
123
  Field(name="patient", type="input_col", title=" "),
124
  Field(type="expander", title="**Session Transcription:** *(expand)*", children=[
125
  Field(name="dialogue_name", type="input_col", title=""),
@@ -140,7 +138,8 @@ fields: List[Field] = [
140
  Field(name="rupture_marker", type="rupture_markers",
141
  title="Select rupture markers noted in the session, include line numbers where rupture is found.", mandatory=False),
142
  ]),
143
-
 
144
  Field(type="container", title="#### True-To-Patient-Prompt Features", children=[
145
  Field(type="expander", title="**Patient Role Description:** *(expand)*", children=[
146
  Field(name="role_name", type="input_col", title=""),
@@ -198,6 +197,9 @@ fields: List[Field] = [
198
  Field(name="other_comments", type="text", title="Please provide any additional details or information:", mandatory=False),
199
  ]),
200
  ]
 
 
 
201
  url_conditional_fields = [
202
  Field(name="skip", type="skip_checkbox",
203
  title="*I am uncomfortable annotating this text and voluntarily skip this instance*", mandatory=False)
@@ -214,6 +216,8 @@ INPUT_FIELD_DEFAULT_VALUES = {'slider': 0,
214
  SHOW_HELP_ICON = False
215
  SHOW_VALIDATION_ERROR_MESSAGE = True
216
 
 
 
217
  ########################################################################################
218
  if filesystem == 'hf':
219
  HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
@@ -535,6 +539,7 @@ def validate_current_page(fields: List[Field], index: int) -> bool:
535
 
536
  # Function to navigate rows
537
  def navigate(index_change):
 
538
  st.session_state.current_index += index_change
539
  # only works consistently if done before rerun
540
  js = '''
@@ -640,8 +645,6 @@ def show_field(f: Field, index: int, data_collected):
640
  if data_collected else INPUT_FIELD_DEFAULT_VALUES[f.type]
641
  )
642
  value = _ensure_key(key, default_val)
643
- #value = st.session_state[key] if key in st.session_state else \
644
- # (data_collected[f.name] if data_collected else INPUT_FIELD_DEFAULT_VALUES[f.type])
645
  if not SHOW_HELP_ICON:
646
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
647
 
@@ -765,14 +768,33 @@ def show_fields(fields: List[Field]):
765
  # mark that the page has been rendered at least once (if you still use this elsewhere)
766
  st.session_state.form_displayed = st.session_state.current_index
767
 
768
- def prep_and_save_data(index, skip_sample):
 
 
 
 
 
 
 
 
 
 
 
 
769
  payload = {
770
  'user_id': st.session_state.user_id,
771
  'index': st.session_state.current_index,
772
  **(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if 0 <= index < len(st.session_state.data) else {}),
773
  **{k: st.session_state[k + str(index)] for k in st.session_state.data_inputs_keys},
774
- 'skip': skip_sample
 
775
  }
 
 
 
 
 
 
776
  # normalize rupture markers -> always write 8 slots
777
  count_key = f"rupture_count_{index}"
778
  count = st.session_state.get(count_key, 1)
@@ -868,10 +890,14 @@ if 'current_index' not in st.session_state:
868
  st.session_state.current_index = start_index+1
869
  st.session_state.form_displayed = -2
870
 
 
 
 
871
  if get_param_from_url('show_extra_fields'):
872
- fields += url_conditional_fields
873
  else:
874
- fields += url_conditional_fields
 
875
 
876
  def add_validated_submit(fields, message):
877
  st.session_state.form_displayed = st.session_state.current_index
@@ -907,17 +933,29 @@ elif st.session_state.current_index == -1:
907
  add_validated_submit([st.session_state.user_id], "Please enter a valid user ID")
908
 
909
  elif st.session_state.current_index < len(st.session_state.data):
910
- show_fields(fields)
 
 
 
911
 
912
  # Action buttons
913
  c1, c2, c3 = st.columns([1,1,6])
914
  with c1:
915
- if st.button("Submit"):
916
- if validate_current_page(fields, st.session_state.current_index):
 
 
 
917
  with st.spinner("saving"):
918
- prep_and_save_data(st.session_state.current_index, ('skip' in st.session_state and st.session_state['skip']))
919
- st.success("Saved!")
920
- navigate(1)
 
 
 
 
 
 
921
 
922
  elif st.session_state.current_index == len(st.session_state.data):
923
  st.write(f"**Thank you for taking part in this study!** \n ")
@@ -927,10 +965,16 @@ elif st.session_state.current_index == len(st.session_state.data):
927
  # Navigation buttons
928
  if 0 < st.session_state.current_index < len(st.session_state.data):
929
  if st.button("Previous"):
930
- navigate(-1)
 
 
 
 
 
 
931
 
932
  if 0 <= st.session_state.current_index < len(st.session_state.data):
933
- st.write(f"Page {st.session_state.current_index + 1} out of {len(st.session_state.data)}")
934
 
935
  st.markdown(
936
  """<style>
 
1
  import json
2
  import os
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
  from dataclasses import dataclass, field
7
  from typing import List, Optional, Dict
8
  from PIL import Image
9
 
10
  import streamlit as st
11
+ import streamlit.components.v1 as components
12
+ from streamlit.components.v1 import html as st_html
13
  from markdown import markdown
14
  st.set_page_config(layout="wide", page_title="Annotation of Simulated Patients", initial_sidebar_state="collapsed")
 
15
  import re, textwrap, html as py_html
16
  from pathlib import Path
 
 
 
 
17
  from fsspec.implementations.local import LocalFileSystem
18
  from huggingface_hub import HfFileSystem
19
 
 
 
20
  @dataclass
21
  class Field:
22
  type: str
 
117
  - *I agree to participate in this research and I want to continue with the survey.*
118
  '''
119
 
120
+ fields0: List[Field] = [
121
  Field(name="patient", type="input_col", title=" "),
122
  Field(type="expander", title="**Session Transcription:** *(expand)*", children=[
123
  Field(name="dialogue_name", type="input_col", title=""),
 
138
  Field(name="rupture_marker", type="rupture_markers",
139
  title="Select rupture markers noted in the session, include line numbers where rupture is found.", mandatory=False),
140
  ]),
141
+ ]
142
+ fields1: List[Field] = [
143
  Field(type="container", title="#### True-To-Patient-Prompt Features", children=[
144
  Field(type="expander", title="**Patient Role Description:** *(expand)*", children=[
145
  Field(name="role_name", type="input_col", title=""),
 
197
  Field(name="other_comments", type="text", title="Please provide any additional details or information:", mandatory=False),
198
  ]),
199
  ]
200
+
201
+ STEPS: List[List[Field]] = [fields0, fields1]
202
+
203
  url_conditional_fields = [
204
  Field(name="skip", type="skip_checkbox",
205
  title="*I am uncomfortable annotating this text and voluntarily skip this instance*", mandatory=False)
 
216
  SHOW_HELP_ICON = False
217
  SHOW_VALIDATION_ERROR_MESSAGE = True
218
 
219
+
220
+
221
  ########################################################################################
222
  if filesystem == 'hf':
223
  HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
 
539
 
540
  # Function to navigate rows
541
  def navigate(index_change):
542
+ st.session_state.step = 0
543
  st.session_state.current_index += index_change
544
  # only works consistently if done before rerun
545
  js = '''
 
645
  if data_collected else INPUT_FIELD_DEFAULT_VALUES[f.type]
646
  )
647
  value = _ensure_key(key, default_val)
 
 
648
  if not SHOW_HELP_ICON:
649
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
650
 
 
768
  # mark that the page has been rendered at least once (if you still use this elsewhere)
769
  st.session_state.form_displayed = st.session_state.current_index
770
 
771
+ def iter_all_input_fields():
772
+ """Yield all Field objects which are real input widgets, across all steps."""
773
+ def walk(nodes):
774
+ for f in nodes:
775
+ if f.children:
776
+ walk(f.children)
777
+ elif f.type in INPUT_FIELD_DEFAULT_VALUES:
778
+ yield f
779
+ for step_fields in STEPS:
780
+ yield from walk(step_fields)
781
+
782
+
783
+ def prep_and_save_data(index, skip_sample, completed: bool):
784
  payload = {
785
  'user_id': st.session_state.user_id,
786
  'index': st.session_state.current_index,
787
  **(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if 0 <= index < len(st.session_state.data) else {}),
788
  **{k: st.session_state[k + str(index)] for k in st.session_state.data_inputs_keys},
789
+ 'skip': skip_sample,
790
+ 'completed': completed,
791
  }
792
+
793
+ for f in iter_all_input_fields():
794
+ key = f.name + str(index)
795
+ val = st.session_state.get(key, INPUT_FIELD_DEFAULT_VALUES[f.type])
796
+ payload[f.name] = val
797
+
798
  # normalize rupture markers -> always write 8 slots
799
  count_key = f"rupture_count_{index}"
800
  count = st.session_state.get(count_key, 1)
 
890
  st.session_state.current_index = start_index+1
891
  st.session_state.form_displayed = -2
892
 
893
+ if 'step' not in st.session_state:
894
+ st.session_state.step = 0
895
+
896
  if get_param_from_url('show_extra_fields'):
897
+ fields1 += url_conditional_fields
898
  else:
899
+ fields1 += url_conditional_fields
900
+
901
 
902
  def add_validated_submit(fields, message):
903
  st.session_state.form_displayed = st.session_state.current_index
 
933
  add_validated_submit([st.session_state.user_id], "Please enter a valid user ID")
934
 
935
  elif st.session_state.current_index < len(st.session_state.data):
936
+ step = st.session_state.step
937
+ total_steps = len(STEPS)
938
+
939
+ show_fields(STEPS[step])
940
 
941
  # Action buttons
942
  c1, c2, c3 = st.columns([1,1,6])
943
  with c1:
944
+ label = "Next" if step < total_steps - 1 else "Submit & next session"
945
+ if st.button(label):
946
+ if validate_current_page(STEPS[step], st.session_state.current_index):
947
+ is_last_page = (step == total_steps - 1)
948
+
949
  with st.spinner("saving"):
950
+ prep_and_save_data(st.session_state.current_index,
951
+ ('skip' in st.session_state and st.session_state['skip']),
952
+ completed=is_last_page)
953
+ if is_last_page:
954
+ st.success("Saved!")
955
+ navigate(1)
956
+ else:
957
+ st.session_state.step += 1
958
+ st.rerun()
959
 
960
  elif st.session_state.current_index == len(st.session_state.data):
961
  st.write(f"**Thank you for taking part in this study!** \n ")
 
965
  # Navigation buttons
966
  if 0 < st.session_state.current_index < len(st.session_state.data):
967
  if st.button("Previous"):
968
+ if step > 0:
969
+ st.session_state.step -= 1
970
+ st.rerun()
971
+ else:
972
+ st.session_state.current_index -= 1
973
+ st.session_state.step = total_steps - 1
974
+ st.rerun()
975
 
976
  if 0 <= st.session_state.current_index < len(st.session_state.data):
977
+ st.write(f"Session {st.session_state.current_index + 1} out of {len(st.session_state.data)}")
978
 
979
  st.markdown(
980
  """<style>