emvecchi commited on
Commit
e495b8f
·
verified ·
1 Parent(s): 3111cc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -1
app.py CHANGED
@@ -391,6 +391,61 @@ def render_rupture_markers_widget(index: int, choices: list[str]):
391
 
392
  st.caption(f"{st.session_state[count_key]}/{MAX_RUPTURE_MARKERS} markers")
393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  #################################### Streamlit App ####################################
395
 
396
  # Function to navigate rows
@@ -429,6 +484,7 @@ def navigate(index_change):
429
  )
430
  st.rerun()
431
 
 
432
  def show_field(f: Field, index: int, data_collected):
433
  if f.type not in INPUT_FIELD_DEFAULT_VALUES.keys():
434
  st.session_state.following_mandatory = False
@@ -559,7 +615,7 @@ def show_field(f: Field, index: int, data_collected):
559
  st.session_state.unacceptable_response = False
560
  st.error(f"Mandatory field")
561
 
562
- '''
563
  def show_fields(fields: List[Field]):
564
  st.session_state.valid = True
565
  index = st.session_state.current_index
@@ -587,6 +643,125 @@ def show_fields(fields: List[Field]):
587
 
588
  st.session_state.form_displayed = st.session_state.current_index
589
  '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  def show_fields(fields: List[Field]):
592
  index = st.session_state.current_index
 
391
 
392
  st.caption(f"{st.session_state[count_key]}/{MAX_RUPTURE_MARKERS} markers")
393
 
394
+ def _get_value_for_field(f: Field, index: int):
395
+ """Return the session value for input fields or None."""
396
+ if f.type in INPUT_FIELD_DEFAULT_VALUES:
397
+ key = f.name + str(index)
398
+ return st.session_state.get(key, INPUT_FIELD_DEFAULT_VALUES[f.type])
399
+ return None
400
+
401
+ def _is_default_value(f: Field, val):
402
+ """Check if current value equals the default for this widget type."""
403
+ return val == INPUT_FIELD_DEFAULT_VALUES.get(f.type)
404
+
405
+ def validate_current_page(fields: List[Field], index: int) -> bool:
406
+ """
407
+ Walk the field tree and verify mandatory inputs are filled.
408
+ Honors `following_mandatory_values`: if a field's value is in that list,
409
+ subsequent siblings become mandatory.
410
+ """
411
+ ok = True
412
+
413
+ def walk(nodes: List[Field], following_required: bool = False):
414
+ nonlocal ok
415
+ # local flag that can be turned on by a node for its following siblings
416
+ req_for_following = following_required
417
+
418
+ for f in nodes:
419
+ if f.children:
420
+ # containers/expanders: recurse; pass the current rule downward
421
+ walk(f.children, req_for_following)
422
+ continue
423
+
424
+ if f.type in INPUT_FIELD_DEFAULT_VALUES:
425
+ val = _get_value_for_field(f, index)
426
+
427
+ # decide if THIS field is required
428
+ required = bool(f.mandatory or req_for_following)
429
+
430
+ # if required and default → error
431
+ if required and _is_default_value(f, val):
432
+ ok = False
433
+
434
+ # update the "following required" rule for siblings
435
+ if f.following_mandatory_values:
436
+ try:
437
+ if val in f.following_mandatory_values:
438
+ req_for_following = True
439
+ except Exception:
440
+ pass
441
+
442
+ # structural types (markdown/input_col/etc.) don’t affect validation
443
+
444
+ walk(fields, False)
445
+
446
+ if not ok:
447
+ st.error("Please fill in all mandatory fields.")
448
+ return ok
449
  #################################### Streamlit App ####################################
450
 
451
  # Function to navigate rows
 
484
  )
485
  st.rerun()
486
 
487
+ '''
488
  def show_field(f: Field, index: int, data_collected):
489
  if f.type not in INPUT_FIELD_DEFAULT_VALUES.keys():
490
  st.session_state.following_mandatory = False
 
615
  st.session_state.unacceptable_response = False
616
  st.error(f"Mandatory field")
617
 
618
+
619
  def show_fields(fields: List[Field]):
620
  st.session_state.valid = True
621
  index = st.session_state.current_index
 
643
 
644
  st.session_state.form_displayed = st.session_state.current_index
645
  '''
646
+ def show_field(f: Field, index: int, data_collected):
647
+ """Render a Field tree. No form logic; everything runs outside forms."""
648
+ # Non-input widgets (custom structural types)
649
+ if f.type not in INPUT_FIELD_DEFAULT_VALUES.keys():
650
+ match f.type:
651
+ case 'input_col':
652
+ # Safely read a column from the current CSV row (or None if missing)
653
+ value = st.session_state.data.iloc[index][f.name] if (
654
+ f.name and f.name in st.session_state.data.columns
655
+ ) else None
656
+
657
+ # Render based on field name
658
+ if f.name == 'image_name' and value:
659
+ display_image(os.path.join(input_repo_path, 'images', value))
660
+ elif f.name == 'dialogue_name' and value:
661
+ render_dialogue(
662
+ os.path.join(input_repo_path, 'dialogues', value),
663
+ width_chars=115, height_px=520, show_border=False
664
+ )
665
+ elif f.name == 'patient' and value:
666
+ st.markdown(f"## Patient:  {value}")
667
+ elif value not in (None, np.nan, ""):
668
+ # generic fallback
669
+ st.write(f.title)
670
+ st.write(value)
671
+
672
+ case 'markdown':
673
+ # If a file path is provided, load & render it; otherwise render the title string
674
+ path = f.other_params.get("path") if f.other_params else None
675
+ if path:
676
+ content = load_text(os.path.join(input_repo_path, path))
677
+ st.markdown(content, unsafe_allow_html=True)
678
+ else:
679
+ st.markdown(f.title)
680
+
681
+ case 'expander':
682
+ # Markdown header above; expander label is plain text
683
+ st.markdown(f.title)
684
+ with st.expander(""):
685
+ for child in (f.children or []):
686
+ show_field(child, index, data_collected)
687
+
688
+ case 'container':
689
+ with st.container(border=True):
690
+ st.markdown(f.title)
691
+ for child in (f.children or []):
692
+ show_field(child, index, data_collected)
693
+
694
+ case 'skip_checkbox':
695
+ # outside forms this is fine
696
+ st.checkbox(f.title, key=f.name, value=False)
697
+
698
+ case 'rupture_markers':
699
+ # dynamic widget (outside forms)
700
+ render_rupture_markers_widget(index, rupture_choices)
701
+
702
+ return # done with non-input types
703
+
704
+ # Input widgets (saved via data_inputs_keys)
705
+ key = f.name + str(index)
706
+ st.session_state.data_inputs_keys.append(f.name)
707
+
708
+ # Initial value: session_state > saved JSON > default
709
+ if key in st.session_state:
710
+ value = st.session_state[key]
711
+ elif data_collected and f.name in data_collected:
712
+ value = data_collected[f.name]
713
+ else:
714
+ value = INPUT_FIELD_DEFAULT_VALUES[f.type]
715
+
716
+ # Render inputs
717
+ match f.type:
718
+ case 'checkbox':
719
+ st.checkbox(f.title, key=key, value=value, help=f.help)
720
+
721
+ case 'radio':
722
+ labels = f.other_params.get('labels') if f.other_params and f.other_params.get('labels') else default_labels
723
+ st.radio(f.title,
724
+ options=range(len(labels)),
725
+ format_func=lambda x: labels[x],
726
+ key=key,
727
+ index=value if value is not None else 0,
728
+ help=f.help, horizontal=False)
729
+
730
+ case 'slider':
731
+ st.slider(f.title, min_value=0, max_value=6, step=1, key=key, value=value, help=f.help)
732
+
733
+ case 'select_slider':
734
+ labels = f.other_params.get('labels') if f.other_params and f.other_params.get('labels') else default_labels
735
+ st.select_slider(f.title,
736
+ options=[0, 20, 40, 60, 80, 100],
737
+ format_func=lambda x: labels[x // 20],
738
+ key=key, value=value, help=f.help)
739
+
740
+ case 'multiselect':
741
+ choices = f.other_params.get('choices') if f.other_params and f.other_params.get('choices') else default_choices
742
+ st.multiselect(f.title, options=choices, format_func=lambda x: x,
743
+ key=key, max_selections=3, default=value, help=f.help)
744
+
745
+ case 'likert_radio':
746
+ labels = f.other_params.get('labels') if f.other_params and f.other_params.get('labels') else default_labels
747
+ st.radio(f.title, options=[0, 1, 2, 3, 4],
748
+ format_func=lambda x: labels[x],
749
+ key=key, index=value if value is not None else 0,
750
+ help=f.help, horizontal=True)
751
+
752
+ case 'y_n_radio':
753
+ labels = f.other_params.get('labels') if f.other_params and f.other_params.get('labels') else yes_no_labels
754
+ st.radio(f.title, options=[0, 1],
755
+ format_func=lambda x: labels[x],
756
+ key=key, index=value if value is not None else 0,
757
+ help=f.help, horizontal=True)
758
+
759
+ case 'text':
760
+ st.text_input(f.title, key=key, value=value if value is not None else "", max_chars=None)
761
+
762
+ case 'textarea':
763
+ st.text_area(f.title, key=key, value=value if value is not None else "", max_chars=None)
764
+
765
 
766
  def show_fields(fields: List[Field]):
767
  index = st.session_state.current_index