Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Dict | |
| from PIL import Image | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| from streamlit.components.v1 import html as st_html | |
| from markdown import markdown | |
| st.set_page_config(layout="wide", page_title="Annotation of Simulated Patients", initial_sidebar_state="collapsed") | |
| import re, textwrap, html as py_html | |
| from pathlib import Path | |
| from fsspec.implementations.local import LocalFileSystem | |
| from huggingface_hub import HfFileSystem | |
| from annotation_questions import ( | |
| Field, | |
| STEPS, | |
| INPUT_FIELD_DEFAULT_VALUES, | |
| COLS_TO_SAVE, | |
| consent_text, | |
| SHOW_HELP_ICON, | |
| SHOW_VALIDATION_ERROR_MESSAGE, | |
| yes_no_labels, # if you still use these in y_n_radio | |
| default_labels # if you still use this fallback | |
| # plus any others you rely on in app.py | |
| ) | |
########################################################################################
# URL query-parameter helper


def get_param_from_url(param: str, default: str = ""):
    """Return query parameter *param* from the page URL, or *default* if absent.

    Used below to read ``user_id`` and ``batch``.
    """
    query = st.query_params
    return query.get(param, default)
# Storage backend: 'local' or 'hf'. 'hf' uses the Hugging Face Hub file system,
# which limits the number of accesses per hour.
filesystem = 'hf'

# Inputs are read below input_repo_path; annotations are written below output_repo_path.
input_repo_path = 'datasets/emvecchi/therapy_annotation'
output_repo_path = 'datasets/emvecchi/therapy_annotation/pilot'
to_annotate_file_name = 'to_annotate.csv'  # CSV file to annotate

# Initial ("unanswered") value per input widget type. An input whose value still
# equals this default counts as not filled in by the mandatory checks.
# NOTE(review): this and the two flags below re-bind names just imported from
# annotation_questions, so these local values win; confirm which copy is canonical.
INPUT_FIELD_DEFAULT_VALUES = dict(
    slider=0,
    text='',
    textarea='',
    checkbox=False,
    radio=None,
    select_slider=0,
    multiselect=[],
    likert_radio=None,
    y_n_radio=None,
)
SHOW_HELP_ICON = False
SHOW_VALIDATION_ERROR_MESSAGE = True
########################################################################################
# File-system handle used for every read and write below.
if filesystem != 'hf':
    hf_fs = LocalFileSystem()
else:
    HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
    # Startup diagnostic: reports whether the write token is missing.
    print("is none?", HF_TOKEN is None)
    hf_fs = HfFileSystem(token=HF_TOKEN)
def get_start_index():
    """Return the number of files saved for the current user, or -3.

    The count is used as the index to resume at. -3 (the consent page) is
    returned whenever listing the user's output directory fails, e.g.
    because nothing has been saved for this user yet.
    """
    user_dir = f"{output_repo_path}/{get_base_path()}"
    try:
        saved_files = hf_fs.ls(user_dir)
    except Exception:
        return -3
    return len(saved_files)
def read_data():
    """Load the CSV of items to annotate from ``input_repo_path`` as a DataFrame."""
    csv_path = f"{input_repo_path}/{to_annotate_file_name}"
    with hf_fs.open(csv_path) as handle:
        return pd.read_csv(handle)
def read_saved_data():
    """Return the saved JSON record for the current user and index, or None.

    None is returned when no file exists at ``output_repo_path/get_path()``
    or when its content is not valid JSON (the decode error is printed).
    """
    remote_path = f"{output_repo_path}/{get_path()}"
    if not hf_fs.exists(remote_path):
        return None
    with hf_fs.open(remote_path) as handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError as err:
            print(err)
    return None
# Write a remote file


def numpy_json_default(obj):
    """``json.dumps`` hook that converts NumPy scalars/arrays to native Python values.

    Row values copied out of the pandas DataFrame (prep_and_save_data uses
    ``.to_dict()`` on a row) can be NumPy scalars on some pandas versions,
    which the ``json`` module cannot encode on its own.

    Raises:
        TypeError: for any other non-serialisable object, like json's default.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_data(data):
    """Write *data* as JSON to ``output_repo_path/get_path()``.

    The user's directory is created first if it does not exist yet.
    """
    user_dir = f"{output_repo_path}/{get_base_path()}"
    if not hf_fs.exists(user_dir):
        hf_fs.mkdir(user_dir)
    with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f:
        f.write(json.dumps(data, default=numpy_json_default))
def get_base_path():
    """Return the per-user directory name: the session's user id as a string."""
    return str(st.session_state.user_id)


def get_path():
    """Return ``<user_id>/<current_index>.json``, relative to ``output_repo_path``."""
    return "{}/{}.json".format(get_base_path(), st.session_state.current_index)
def display_dialogue(hf_path: str):
    """Render the Markdown file at *hf_path* in the page (raw HTML allowed).

    The file is read once, through load_text(), which honours the configured
    ``filesystem``. A second, discarded read through ``hf_fs`` would only cost
    an extra (rate-limited) request.
    """
    txt = load_text(hf_path)
    st.markdown(txt, unsafe_allow_html=True)
def _ensure_key(key: str, default):
    """Return ``st.session_state[key]``, first storing *default* there if the key is missing."""
    state = st.session_state
    if key in state:
        return state[key]
    state[key] = default
    return state[key]
# Speaker-prefixed paragraph such as "**T:** ..." or "**P:** ...": group 1 is the
# speaker letter, group 2 the rest of the paragraph. re.DOTALL lets group 2 span
# the line breaks of a multi-line paragraph. Without it, such paragraphs never
# matched and were rendered with a literal "**T:**" instead of a bold label.
SPEAKER_RE = re.compile(r'^\s*\*\*(T|P):\*\*\s*(.*)$', re.DOTALL)


def read_md(path: str) -> str:
    """Return the UTF-8 text of the Markdown file at *path*.

    With ``filesystem == 'hf'`` the file is read through ``hf_fs``; otherwise
    *path* is read from the local disk.
    """
    if filesystem == 'hf':
        with hf_fs.open(path, "rb") as f:
            return f.read().decode("utf-8")
    return Path(path).read_text(encoding="utf-8")
def wrap_para(text: str, width: int) -> list[str]:
    """Wrap *text* into lines of at most *width* characters where possible.

    Surrounding whitespace is stripped, non-breaking spaces are treated as
    ordinary spaces, and words are never split (neither mid-word nor at
    hyphens), so an over-long word gets a line of its own. Blank input yields
    ``[""]``, so callers always receive at least one line.
    """
    stripped = text.strip()
    if stripped:
        lines = textwrap.wrap(
            stripped.replace("\u00A0", " "),
            width=width,
            break_long_words=False,
            break_on_hyphens=False,
            drop_whitespace=True,
            replace_whitespace=True,
        )
        if lines:
            return lines
    return [""]
def md_dialogue_to_visual_lines(md_text: str, width: int) -> list[str]:
    """Split a Markdown dialogue into wrapped, HTML-escaped visual lines.

    Paragraphs (separated by blank lines) are wrapped to *width* characters.
    A paragraph matching SPEAKER_RE gets ``<strong>T:</strong>`` /
    ``<strong>P:</strong>`` in front of its first line. All text is escaped
    with ``html.escape``. One empty string separates consecutive paragraphs,
    and no trailing empty line is returned.
    """
    normalized = md_text.replace("\r\n", "\n").replace("\r", "\n").strip("\n")
    lines: list[str] = []
    for raw_para in re.split(r"\n\s*\n", normalized):
        para = raw_para.strip()
        if not para:
            lines.append("")
            continue
        speaker_match = SPEAKER_RE.match(para)
        if speaker_match is None:
            lines.extend(py_html.escape(chunk) for chunk in wrap_para(para, width))
        else:
            speaker, content = speaker_match.groups()
            first, *rest = wrap_para(content, width)
            lines.append(f"<strong>{speaker}:</strong> {py_html.escape(first)}".rstrip())
            lines.extend(py_html.escape(chunk) for chunk in rest)
        lines.append("")  # separator between paragraphs
    if lines and lines[-1] == "":
        lines.pop()
    return lines
def render_markdown_simple(md_path, height_px=520):
    """Render the Markdown file at *md_path* as HTML inside a scrollable box.

    The box is at most *height_px* pixels tall; longer content scrolls.
    """
    body_html = markdown(read_md(md_path), extensions=['extra', 'codehilite'])
    wrapper = f"<div style='max-height:{height_px}px; overflow:auto'>{body_html}</div>"
    st.markdown(wrapper, unsafe_allow_html=True)
def render_dialogue(md_path: str,
                    width_chars: int = 100,
                    height_px: int = 520,
                    font_family: str = "system-ui, -apple-system, 'Segoe UI', Roboto, Helvetica, Arial, 'Noto Sans', sans-serif",
                    font_size: str = "1.05rem",
                    show_border: bool = False):
    """Render a .md dialogue with a line-number gutter inside an HTML iframe.

    The dialogue is re-wrapped to *width_chars* characters per visual line
    (md_dialogue_to_visual_lines) and laid out as a two-column table, which
    keeps the numbers aligned with the text. The iframe keeps Streamlit's own
    CSS from interfering. *show_border* adds a border, rounded corners and
    padding, plus 16px of extra iframe height.
    """
    visual_lines = md_dialogue_to_visual_lines(read_md(md_path), width_chars)

    if show_border:
        border_css, radius_css, padding_css = "1px solid #e6e6e6", ".6rem", ".8rem 1rem"
    else:
        border_css, radius_css, padding_css = "none", "0", "0"

    # Table rows (the table keeps the gutter aligned). Empty visual lines get a
    # single space, presumably so the row keeps its height.
    table_rows = "\n".join([
        f"<tr><td class='num'>{number}</td><td class='txt'>{line or ' '}</td></tr>"
        for number, line in enumerate(visual_lines, start=1)
    ])

    html_doc = f"""
    <!doctype html>
    <meta charset="utf-8">
    <style>
    :root {{
    --font-text: {font_family};
    --font-size: {font_size};
    }}
    html, body {{
    margin:0; padding:0; background:transparent;
    font-family: var(--font-text); font-size: var(--font-size);
    -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale;
    font-synthesis: none; /* avoid auto fake-bold */
    }}
    .root {{
    border:{border_css}; border-radius:{radius_css}; padding:{padding_css};
    height:{height_px}px; overflow:auto; background:transparent;
    line-height:1.5;
    }}
    table {{ border-collapse:collapse; table-layout:fixed; width:max-content; max-width:100%; }}
    td {{ padding:0; vertical-align:top; }}
    .num {{
    width:4ch; padding-right:1ch; text-align:right; color:rgba(0,0,0,.55);
    user-select:none; font-variant-numeric: tabular-nums;
    font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, 'Liberation Mono', monospace;
    }}
    .txt {{
    white-space: pre-wrap; word-break: break-word; max-width:{width_chars}ch;
    }}
    strong {{ font-weight:700; }}
    </style>
    <div class="root">
    <table>
    {table_rows}
    </table>
    </div>
    """
    # The iframe is sized to the inner height, so no outer scrollbar flashes.
    extra_height = 16 if show_border else 0
    st_html(html_doc, height=height_px + extra_height, scrolling=False)
def _get_value_for_field(f: Field, index: int):
    """Return the session-state value of input field *f* for item *index*.

    If nothing is stored under ``f.name + str(index)`` yet, the default for the
    widget type is returned. Non-input (structural) fields give None.
    """
    if f.type not in INPUT_FIELD_DEFAULT_VALUES:
        return None
    return st.session_state.get(f"{f.name}{index}", INPUT_FIELD_DEFAULT_VALUES[f.type])
def _is_default_value(f: Field, val):
    """Return True if *val* equals the default value for *f*'s widget type."""
    type_default = INPUT_FIELD_DEFAULT_VALUES.get(f.type)
    return val == type_default
def validate_current_page(fields: List[Field], index: int) -> bool:
    """Return True if every required input on the current page is filled in.

    This applies the same rules show_field() uses to place the red "required"
    marker, so exactly the marked inputs are enforced:

    * an input is required when ``f.mandatory`` is set, or while the
      "following mandatory" flag is on;
    * the flag is switched on by an input whose value is in its
      ``following_mandatory_values`` (that input then counts as required too);
    * the flag is switched off by any non-input field (markdown, input_col,
      expander, container, ...), and by an input with ``skip_mandatory`` set
      (before that input is checked);
    * the flag is one running state carried through the page in render order,
      into and out of expanders/containers.

    An input counts as unfilled while its value equals its widget-type default.
    On failure an error message is shown with ``st.error``.

    Args:
        fields: the fields of the current step (e.g. ``STEPS[step]``).
        index: data row whose widget keys (``name + str(index)``) are read.
    """
    missing = False
    following = False

    def value_in(value, candidates) -> bool:
        # An incomparable/unhashable value (e.g. a multiselect list checked
        # against a set) simply does not trigger the rule.
        try:
            return value in candidates
        except TypeError:
            return False

    def walk(nodes: List[Field]) -> None:
        nonlocal missing, following
        for f in nodes:
            if f.type not in INPUT_FIELD_DEFAULT_VALUES:
                following = False
                # show_field only renders the children of expanders/containers.
                if f.type in ('expander', 'container') and f.children:
                    walk(f.children)
                continue
            if following and f.skip_mandatory:
                following = False
            val = _get_value_for_field(f, index)
            if f.following_mandatory_values and value_in(val, f.following_mandatory_values):
                following = True
            if (f.mandatory or following) and _is_default_value(f, val):
                missing = True

    walk(fields)
    if missing:
        st.error("Please fill in all mandatory fields.")
    return not missing
#################################### Streamlit App ####################################


def navigate(index_change):
    """Move *index_change* items forward or back, reset to step 0, and rerun.

    Before the rerun, two zero-height HTML components are injected:
    - one scrolls the parent page to its first <h1> after 100 ms;
    - the other cancels the Enter key on every <input> present in the parent
      document at that moment.
    """
    session = st.session_state
    session.step = 0
    session.current_index += index_change

    # Only works consistently if done before the rerun.
    scroll_js = '''
    <script>
    setTimeout(function() {
    var titleElement = window.parent.document.querySelector("h1");
    if (titleElement) {
    titleElement.scrollIntoView({behavior: "smooth", block: "start"});
    }
    }, 100);
    </script>
    '''
    components.html(scroll_js, height=0)

    # Disable "press Enter to submit" on text inputs. See:
    # https://discuss.streamlit.io/t/click-twice-on-button-for-changing-state/45633/2
    # https://discuss.streamlit.io/t/text-input-how-to-disable-press-enter-to-apply/14457/6
    no_enter_js = """
    <script>
    const inputs = window.parent.document.querySelectorAll('input');
    inputs.forEach(input => {
    input.addEventListener('keydown', function(event) {
    if (event.key === 'Enter') {
    event.preventDefault();
    }
    });
    });
    </script>
    """
    components.html(no_enter_js, height=0)
    st.rerun()
| def show_field(f: Field, index: int, data_collected): | |
| if f.type not in INPUT_FIELD_DEFAULT_VALUES.keys(): | |
| st.session_state.following_mandatory = False | |
| match f.type: | |
| case 'input_col': | |
| value = ( | |
| st.session_state.data.iloc[index][f.name] | |
| if f.name and f.name in st.session_state.data.columns | |
| else None | |
| ) | |
| #elif f.name == 'dialogue_name' and value: | |
| # render_dialogue( | |
| # os.path.join(input_repo_path, 'dialogues', value), | |
| # width_chars=115, height_px=520, show_border=False | |
| # ) | |
| if f.name == 'dialogue_name' and value: | |
| render_markdown_simple(os.path.join(input_repo_path, 'dialogues', st.session_state.batch, value), height_px=720) | |
| elif f.name == 'role_name' and value: | |
| render_markdown_simple(os.path.join(input_repo_path, 'role_descriptions', value), height_px=520) | |
| elif f.name == 'patient' and value: | |
| st.markdown(f"#### Patient: {value}") | |
| elif value not in (None, np.nan, ""): | |
| if f.title: | |
| st.write(f.title) | |
| st.write(value) | |
| case 'markdown': | |
| if f.other_params and f.other_params.get("use_dialogue_name"): | |
| row = st.session_state.data.iloc[index] | |
| dialogue_name = row["dialogue_name"] | |
| md_path = os.path.join(input_repo_path, "dialogues", st.session_state.batch, dialogue_name) | |
| content = load_text(md_path) | |
| st.markdown(content, unsafe_allow_html=True) | |
| elif f.other_params and f.other_params.get("use_roledesc_name"): | |
| row = st.session_state.data.iloc[index] | |
| role_name = row["role_name"] | |
| md_path = os.path.join(input_repo_path, "role_descriptions", role_name) | |
| content = load_text(md_path) | |
| st.markdown(content, unsafe_allow_html=True) | |
| elif f.other_params and f.other_params.get("use_guidelines"): | |
| md_path = os.path.join(input_repo_path, "guidelines.md") | |
| content = load_text(md_path) | |
| st.markdown(content, unsafe_allow_html=True) | |
| elif f.other_params and f.other_params.get("instruction_content"): | |
| content = f.other_params.get("content", "") | |
| if content: | |
| st.markdown(content, unsafe_allow_html=True) | |
| else: | |
| path = f.other_params.get("path") if f.other_params else None | |
| if path: | |
| content = load_text(os.path.join(input_repo_path, path)) | |
| st.markdown(content, unsafe_allow_html=True) | |
| else: | |
| st.markdown(f.title) | |
| case 'expander': | |
| with (st.expander(f.title) if f.type == 'expander' else st.container(border=True)): | |
| for child in f.children: | |
| show_field(child, index, data_collected) | |
| case 'container': | |
| with st.container(border=True): | |
| st.markdown(f.title) | |
| for child in (f.children or []): | |
| show_field(child, index, data_collected) | |
| case 'skip_checkbox': | |
| st.checkbox(f.title, key=f.name, value=False) | |
| return | |
| else: | |
| key = f.name + str(index) | |
| # track the logical field name once per page; prep_and_save_data relies on this | |
| st.session_state.data_inputs_keys.append(f.name) | |
| if data_collected and f.name in data_collected: | |
| st.session_state[key] = data_collected[f.name] | |
| elif key not in st.session_state: | |
| st.session_state[key] = INPUT_FIELD_DEFAULT_VALUES[f.type] | |
| value = st.session_state[key] | |
| label = f.title | |
| if not SHOW_HELP_ICON: | |
| label = f'**{label}**\n\n{f.help}' if f.help else label | |
| validation_error = False | |
| if st.session_state.form_displayed == st.session_state.current_index: | |
| if st.session_state.following_mandatory and f.skip_mandatory: | |
| st.session_state.following_mandatory = False | |
| if f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values: | |
| st.session_state.following_mandatory = True | |
| if f.mandatory or st.session_state.following_mandatory: | |
| if value == INPUT_FIELD_DEFAULT_VALUES[f.type]: | |
| #if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]: | |
| st.session_state.valid = False | |
| validation_error = True | |
| elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values: | |
| st.session_state.following_mandatory = True | |
| if f.mandatory or st.session_state.following_mandatory: | |
| label += (" :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else " :red[*]") | |
| #f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]' | |
| help_text = None if not SHOW_HELP_ICON else f.help | |
| #f.help = None | |
| value = st.session_state[key] | |
| match f.type: | |
| case 'checkbox': | |
| st.checkbox(label, key=key, value=value, help=help_text) | |
| case 'radio': | |
| labels = (f.other_params.get('labels') | |
| if f.other_params and f.other_params.get('labels') | |
| else default_labels) | |
| st.radio( | |
| label, | |
| options=range(len(labels)), | |
| format_func=lambda x: labels[x], | |
| key=key, | |
| index=value, | |
| help=help_text, | |
| horizontal=False | |
| ) | |
| case 'slider': | |
| st.slider(label, min_value=0, max_value=6, step=1, key=key, | |
| value=value, help=help_text) | |
| case 'select_slider': | |
| labels = (f.other_params.get('labels') | |
| if f.other_params and f.other_params.get('labels') | |
| else default_labels) | |
| st.select_slider( | |
| label, | |
| options=[0, 20, 40, 60, 80, 100], | |
| format_func=lambda x: labels[x // 20], | |
| key=key, | |
| value=value, | |
| help=help_text | |
| ) | |
| case 'multiselect': | |
| choices = (f.other_params.get('choices') | |
| if f.other_params and f.other_params.get('choices') | |
| else default_choices) | |
| st.multiselect( | |
| label, | |
| options=choices, | |
| format_func=lambda x: x, | |
| key=key, | |
| default=value, | |
| max_selections=3, | |
| help=help_text | |
| ) | |
| case 'likert_radio': | |
| labels = (f.other_params.get('labels') | |
| if f.other_params and f.other_params.get('labels') | |
| else default_labels) | |
| st.radio( | |
| label, | |
| options=[0, 1, 2, 3, 4], | |
| format_func=lambda x: labels[x], | |
| key=key, | |
| index=value, | |
| help=help_text, | |
| horizontal=True | |
| ) | |
| case 'y_n_radio': | |
| labels = (f.other_params.get('labels') | |
| if f.other_params and f.other_params.get('labels') | |
| else yes_no_labels) | |
| st.radio( | |
| label, | |
| options=[0, 1], | |
| format_func=lambda x: labels[x], | |
| key=key, | |
| index=value, | |
| help=help_text, | |
| horizontal=True | |
| ) | |
| case 'text': | |
| st.text_input(label, key=key, value=(value if value is not None else ""), | |
| max_chars=None, help=help_text) | |
| case 'textarea': | |
| st.text_area(label, key=key, value=(value if value is not None else ""), | |
| max_chars=None, help=help_text) | |
| def show_fields(fields: List[Field]): | |
| index = st.session_state.current_index | |
| data_collected = read_saved_data() | |
| st.session_state.data_inputs_keys = [] # rebuilt each render | |
| st.session_state.following_mandatory = False | |
| for field in fields: | |
| show_field(field, index, data_collected) | |
| # mark that the page has been rendered at least once (if you still use this elsewhere) | |
| st.session_state.form_displayed = st.session_state.current_index | |
| def iter_all_input_fields(): | |
| """Yield all Field objects which are real input widgets, across all steps.""" | |
| def walk(nodes): | |
| for f in nodes: | |
| if f.children: | |
| walk(f.children) | |
| elif f.type in INPUT_FIELD_DEFAULT_VALUES: | |
| yield f | |
| for step_fields in STEPS: | |
| yield from walk(step_fields) | |
| def prep_and_save_data(index, skip_sample, completed: bool): | |
| existing = read_saved_data() or {} | |
| base = {} | |
| if 0 <= index < len(st.session_state.data): | |
| base = st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() | |
| payload = { | |
| **existing, | |
| **base, | |
| 'user_id': st.session_state.user_id, | |
| 'batch': st.session_state.batch, | |
| 'index': st.session_state.current_index, | |
| 'skip': skip_sample, | |
| 'completed': int(completed), | |
| } | |
| for k in st.session_state.data_inputs_keys: | |
| key = k + str(index) | |
| if key in st.session_state: | |
| payload[k] = st.session_state[key] | |
| for f in iter_all_input_fields(): | |
| key = f.name + str(index) | |
| val = st.session_state.get(key, INPUT_FIELD_DEFAULT_VALUES[f.type]) | |
| payload[f.name] = val | |
| save_data(payload) | |
# Page title and global CSS: paragraph font size and maximum width of the main section.
st.title("Annotation of Simulated Patients")
_page_css = """<style>
div[data-testid="stMarkdownContainer"] > p {
font-size: 1rem;
}
section.main > div {max-width:80rem}
</style>
"""
st.markdown(_page_css, unsafe_allow_html=True)
def add_annotation_guidelines():
    """Render the annotation guidelines as a collapsible ``<details>`` section.

    Loads ``guidelines.md`` from ``input_repo_path``, the same file the
    guidelines page and ``use_guidelines`` markdown fields load. A bare
    relative "guidelines.md" would not resolve into the dataset repo when
    ``filesystem == 'hf'``.
    """
    guidelines_text = load_text(os.path.join(input_repo_path, "guidelines.md"))
    st.markdown(
        f"<details><summary><b>Annotation Guidelines</b></summary><div>{guidelines_text}</div></details><br>",
        unsafe_allow_html=True)
# Screen-out page: shown, and the script stopped, once
# st.session_state.unacceptable_response is truthy.
if st.session_state.get('unacceptable_response'):
    # NOTE(review): `failed_sanity_check_code` is neither defined nor imported in
    # this file; without this guard the branch crashed with NameError. Define or
    # import it to show the Prolific completion link.
    try:
        completion_code = failed_sanity_check_code
    except NameError:
        completion_code = None
    error_message = "<br>"
    # current_index may not be initialised yet on this run.
    if completion_code and st.session_state.get('current_index', -3) >= -5:
        error_message += (
            "Thank you for your time! You will receive a small compensation for your contribution up to now. <br><br>"
            f'Please return to the study and copy/paste this code: <b>{completion_code}</b>, or '
            f'<a href="https://app.prolific.com/submissions/complete?cfc={completion_code}">Click Here</a>'
        )
    # Display the error message using custom HTML
    st.markdown(
        f"""
        <div style="background-color: #f8d7da; color: #721c24; padding: 10px; border-radius: 5px; border: 1px solid #f5c6cb;">
        <h4>Error: You are not eligible for this study.</h4> {error_message}
        </div>
        """,
        unsafe_allow_html=True
    )
    st.stop()
# Load the items to annotate once per browser session.
if 'data' not in st.session_state:
    st.session_state.data = read_data()

# Take user_id and then batch from the URL, unless already set in this session.
for _param in ("user_id", "batch"):
    _from_url = get_param_from_url(_param)
    if _from_url and not st.session_state.get(_param):
        st.session_state[_param] = _from_url
# Decide which index to open on the first run of a browser session.
if 'current_index' not in st.session_state:
    # get_start_index() builds its path from st.session_state.user_id, which is
    # unset when the URL carries no user_id. In that case start at the consent
    # page (-3), so the id can be entered on the -1 page instead of crashing here.
    start_index = get_start_index() if st.session_state.get("user_id") else -3
    target_index = start_index
    if start_index > 0:
        # Resume: reopen the last saved item if it was not marked completed.
        # Assumes saved files are 0.json .. (start_index - 1).json.
        last_idx = start_index - 1
        last_path = f"{output_repo_path}/{get_base_path()}/{last_idx}.json"
        try:
            with hf_fs.open(last_path, "rb") as f:
                last_data = json.load(f)
        except Exception:
            # Missing or unreadable record: treat the item as completed.
            last_data = {}
        if not bool(last_data.get("completed", 1)):
            target_index = last_idx
    # Index len(data) is the thank-you page, so never go past it. The old rule
    # ("< len-1 else +1") skipped the last item and could land on len(data)+1,
    # which renders no page at all.
    st.session_state.current_index = min(target_index, len(st.session_state.data))
    st.session_state.form_displayed = -3
if 'step' not in st.session_state:
    st.session_state.step = 0
def add_validated_submit(fields, message):
    """Render a form's Submit button and advance one index when it is pressed.

    Must be called inside ``st.form``. On submit, if every value in *fields* is
    falsy, *message* is shown instead of navigating. Also records the current
    index in ``form_displayed``.
    """
    st.session_state.form_displayed = st.session_state.current_index
    if not st.form_submit_button("Submit"):
        return
    if any(fields):
        navigate(1)
    else:
        st.error(message)
def add_checked_submit():
    """Render the consent checkbox and a Submit button that requires it to be ticked."""
    agreed = st.checkbox('I agree', key='consent')
    add_validated_submit([agreed], "Please agree to give your consent to proceed")
def add_guidelines_submit():
    """Render the "guidelines read" checkbox and a Submit button that requires it.

    NOTE(review): this reuses widget key 'consent' from the consent page;
    confirm a ticked consent box cannot carry over and pre-tick this one.
    """
    confirmed = st.checkbox('I have read the guidelines in detail, and I understand the study', key='consent')
    add_validated_submit([confirmed], "Please confirm to have carefully read the guidelines.")
def load_text(path: str) -> str:
    """Return the UTF-8 text at *path*, via ``hf_fs`` ('hf') or the local disk."""
    if filesystem != 'hf':
        with open(path, "r", encoding="utf-8") as fh:
            return fh.read()
    with hf_fs.open(path, "rb") as fh:
        raw = fh.read()
    return raw.decode("utf-8")
| if st.session_state.current_index == -3: | |
| with st.form("data_form"): | |
| st.markdown(consent_text) | |
| add_checked_submit() | |
| elif st.session_state.current_index == -2: | |
| with st.form("data_form"): | |
| st.markdown("# Annotation Guidelines") | |
| md_path = os.path.join(input_repo_path, "guidelines.md") | |
| content = load_text(md_path) | |
| st.markdown(content, unsafe_allow_html=True) | |
| add_guidelines_submit() | |
| elif st.session_state.current_index == -1: | |
| if st.session_state.get('user_id') and st.session_state.get("batch"): | |
| navigate(1) | |
| else: | |
| with st.form("user_id_form"): | |
| st.session_state.user_id = st.text_input("User ID", value="") | |
| add_validated_submit([st.session_state.user_id], "Please enter a valid user ID") | |
| elif st.session_state.current_index < len(st.session_state.data): | |
| step = st.session_state.step | |
| total_steps = len(STEPS) | |
| show_fields(STEPS[step]) | |
| c1, c2, c3 = st.columns([1,5,2]) | |
| with c1: | |
| if st.button("<< Previous"): | |
| if step > 0: | |
| st.session_state.step -= 1 | |
| st.rerun() | |
| else: | |
| st.session_state.current_index -= 1 | |
| st.session_state.step = total_steps - 1 | |
| st.rerun() | |
| with c2: | |
| label = " **Submit >>**" if step < total_steps - 1 else "**Submit & next session**" | |
| if st.button(label): | |
| if validate_current_page(STEPS[step], st.session_state.current_index): | |
| is_last_page = (step == total_steps - 1) | |
| with st.spinner("saving"): | |
| prep_and_save_data(st.session_state.current_index, | |
| ('skip' in st.session_state and st.session_state['skip']), | |
| completed=is_last_page) | |
| if is_last_page: | |
| st.success("Saved!") | |
| navigate(1) | |
| else: | |
| st.session_state.step += 1 | |
| st.rerun() | |
| elif st.session_state.current_index == len(st.session_state.data): | |
| st.write(f"**Thank you for taking part in this study!** \n ") | |
| c1, c2, c3 = st.columns([1,5,2]) | |
| step = st.session_state.step | |
| total_steps = len(STEPS) | |
| with c1: | |
| if st.button("<< Previous"): | |
| if step > 0: | |
| st.session_state.step -= 1 | |
| st.rerun() | |
| else: | |
| st.session_state.current_index -= 1 | |
| st.session_state.step = total_steps - 1 | |
| st.rerun() | |
# Progress indicator on item pages.
_n_items = len(st.session_state.data)
if 0 <= st.session_state.current_index < _n_items:
    st.write(f"Session {st.session_state.current_index + 1} out of {_n_items}")

# Hide Streamlit's input-instructions hint (data-testid="InputInstructions").
_hide_input_hint_css = """<style>
div[data-testid="InputInstructions"] {
visibility: hidden;
}
</style>"""
st.markdown(_hide_input_hint_css, unsafe_allow_html=True)