MaxBenKre
commited on
Commit
·
448ed98
1
Parent(s):
abcb2e3
correct gui code
Browse files- Dockerfile +1 -1
- README.md +1 -6
- src/Start_Page.py +60 -0
- src/__init__.py +0 -0
- src/gui_elements/__init__.py +0 -0
- src/gui_elements/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gui_elements/__pycache__/output_manager.cpython-312.pyc +0 -0
- src/gui_elements/__pycache__/paginator.cpython-312.pyc +0 -0
- src/gui_elements/__pycache__/stateful_widget.cpython-312.pyc +0 -0
- src/gui_elements/output_manager.py +61 -0
- src/gui_elements/paginator.py +68 -0
- src/gui_elements/stateful_widget.py +53 -0
- src/pages/01_Prompt_Configuration.py +311 -0
- src/pages/02_Option_Prompt.py +197 -0
- src/pages/03_Inference_Setting.py +230 -0
- src/pages/04_Final_Overview.py +289 -0
- src/pages/__init__.py +0 -0
- src/pages/old_pages/_01_Basic_Prompt_Settings.py +124 -0
- src/pages/old_pages/_03_Prepare_Prompts.py +178 -0
- src/streamlit_app.py +0 -40
- src/surveyGenGui.egg-info/PKG-INFO +3 -0
- src/surveyGenGui.egg-info/SOURCES.txt +17 -0
- src/surveyGenGui.egg-info/dependency_links.txt +1 -0
- src/surveyGenGui.egg-info/top_level.txt +4 -0
Dockerfile
CHANGED
|
@@ -17,4 +17,4 @@ EXPOSE 8501
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "src/
|
|
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "src/Start_Page.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
README.md
CHANGED
|
@@ -12,9 +12,4 @@ short_description: GUI for the QSTN Framework
|
|
| 12 |
license: mit
|
| 13 |
---
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
-
Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
|
| 18 |
-
|
| 19 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 20 |
-
forums](https://discuss.streamlit.io).
|
|
|
|
| 12 |
license: mit
|
| 13 |
---
|
| 14 |
|
| 15 |
+
# QSTN GUI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/Start_Page.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from io import StringIO
|
| 4 |
+
|
| 5 |
+
from qstn.survey_manager import SurveyCreator
|
| 6 |
+
|
| 7 |
+
from qstn.prompt_builder import LLMPrompt
|
| 8 |
+
|
| 9 |
+
st.set_page_config(layout="wide")
|
| 10 |
+
st.title("Questionnaire")
|
| 11 |
+
|
| 12 |
+
col1, col2 = st.columns(2)
|
| 13 |
+
|
| 14 |
+
# def save_dataframe(uploaded_file):
|
| 15 |
+
|
| 16 |
+
# try:
|
| 17 |
+
# df = pd.read_csv(uploaded_file)
|
| 18 |
+
# # Store the DataFrame in session state
|
| 19 |
+
# st.session_state.df = df
|
| 20 |
+
# except Exception as e:
|
| 21 |
+
# st.error(f"Error reading the file: {e}")
|
| 22 |
+
|
| 23 |
+
# return df
|
| 24 |
+
|
| 25 |
+
df_population = None
|
| 26 |
+
df_questionnaire = None
|
| 27 |
+
|
| 28 |
+
with col1:
|
| 29 |
+
uploaded_questionnaire = st.file_uploader("Select a questionnaire to start with")
|
| 30 |
+
if uploaded_questionnaire is not None:
|
| 31 |
+
# bytes_data = uploaded_questionnaire.getvalue()
|
| 32 |
+
# st.write(bytes_data)
|
| 33 |
+
|
| 34 |
+
df_questionnaire = pd.read_csv(uploaded_questionnaire)
|
| 35 |
+
#dataframe = save_dataframe(uploaded_questionnaire)
|
| 36 |
+
|
| 37 |
+
st.write(df_questionnaire)
|
| 38 |
+
|
| 39 |
+
with col2:
|
| 40 |
+
uploaded_population = st.file_uploader("Select a population to start with")
|
| 41 |
+
if uploaded_population is not None:
|
| 42 |
+
# bytes_data = uploaded_population.getvalue()
|
| 43 |
+
# st.write(bytes_data)
|
| 44 |
+
|
| 45 |
+
df_population = pd.read_csv(uploaded_population)
|
| 46 |
+
#dataframe = save_dataframe(uploaded_population)
|
| 47 |
+
|
| 48 |
+
st.write(df_population)
|
| 49 |
+
|
| 50 |
+
disabled = True
|
| 51 |
+
|
| 52 |
+
if df_population is not None and df_questionnaire is not None:
|
| 53 |
+
disabled = False
|
| 54 |
+
|
| 55 |
+
st.divider()
|
| 56 |
+
|
| 57 |
+
if st.button("Confirm and Prepare Questionnaire", type="primary", disabled=disabled, use_container_width=True):
|
| 58 |
+
questionnaires: list[LLMPrompt] = SurveyCreator.from_dataframe(df_population, df_questionnaire)
|
| 59 |
+
st.session_state.questionnaires = questionnaires
|
| 60 |
+
st.switch_page("pages/01_Prompt_Configuration.py")
|
src/__init__.py
ADDED
|
File without changes
|
src/gui_elements/__init__.py
ADDED
|
File without changes
|
src/gui_elements/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (160 Bytes). View file
|
|
|
src/gui_elements/__pycache__/output_manager.cpython-312.pyc
ADDED
|
Binary file (2.8 kB). View file
|
|
|
src/gui_elements/__pycache__/paginator.cpython-312.pyc
ADDED
|
Binary file (3.18 kB). View file
|
|
|
src/gui_elements/__pycache__/stateful_widget.cpython-312.pyc
ADDED
|
Binary file (2.82 kB). View file
|
|
|
src/gui_elements/output_manager.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import io
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
import contextlib
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
@contextmanager
|
| 8 |
+
def st_capture(output_func):
|
| 9 |
+
with io.StringIO() as stdout, contextlib.redirect_stdout(stdout):
|
| 10 |
+
old_write = stdout.write
|
| 11 |
+
|
| 12 |
+
def new_write(string):
|
| 13 |
+
ret = old_write(string)
|
| 14 |
+
output_func(stdout.getvalue())
|
| 15 |
+
return ret
|
| 16 |
+
|
| 17 |
+
stdout.write = new_write
|
| 18 |
+
yield
|
| 19 |
+
|
| 20 |
+
class TqdmToStreamlit(io.StringIO):
|
| 21 |
+
"""
|
| 22 |
+
A custom file-like object that redirects tqdm's output to Streamlit's st.progress
|
| 23 |
+
and st.text widgets.
|
| 24 |
+
"""
|
| 25 |
+
# def __init__(self, progress_bar, text_element):
|
| 26 |
+
# super().__init__()
|
| 27 |
+
# self.progress_bar = progress_bar
|
| 28 |
+
# self.text_element = text_element
|
| 29 |
+
# self.last_progress = 0
|
| 30 |
+
|
| 31 |
+
def __init__(self, text_element):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.text_element = text_element
|
| 34 |
+
|
| 35 |
+
def write(self, buf):
|
| 36 |
+
# We can use this code if we want to have a nicer progress bar.
|
| 37 |
+
# match = re.search(r"(\d+)%\|", buf)
|
| 38 |
+
# if match:
|
| 39 |
+
# progress = int(match.group(1))
|
| 40 |
+
# if progress > self.last_progress:
|
| 41 |
+
# self.progress_bar.progress(progress)
|
| 42 |
+
# self.text_element.text(buf.strip())
|
| 43 |
+
# self.last_progress = progress
|
| 44 |
+
self.text_element.text(buf.strip())
|
| 45 |
+
|
| 46 |
+
def flush(self):
|
| 47 |
+
pass # No-op
|
| 48 |
+
|
| 49 |
+
class QueueIO(io.StringIO):
|
| 50 |
+
"""
|
| 51 |
+
A custom file-like object that writes to a queue.
|
| 52 |
+
Used to capture stdout/stderr.
|
| 53 |
+
"""
|
| 54 |
+
def __init__(self, q):
|
| 55 |
+
self.queue = q
|
| 56 |
+
|
| 57 |
+
def write(self, buf):
|
| 58 |
+
self.queue.put(buf)
|
| 59 |
+
|
| 60 |
+
def flush(self):
|
| 61 |
+
pass
|
src/gui_elements/paginator.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
def paginator(items, session_state_key):
|
| 4 |
+
"""
|
| 5 |
+
Creates a reusable paginator component for a list of items.
|
| 6 |
+
|
| 7 |
+
Args:
|
| 8 |
+
items (list): The list of items to paginate through.
|
| 9 |
+
session_state_key (str): A unique key for storing the current index in
|
| 10 |
+
st.session_state.
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
Any: The currently selected item from the list.
|
| 14 |
+
"""
|
| 15 |
+
list_length = len(items)
|
| 16 |
+
|
| 17 |
+
# Initialize the session state for the current index if it doesn't exist.
|
| 18 |
+
if session_state_key not in st.session_state:
|
| 19 |
+
st.session_state[session_state_key] = 0
|
| 20 |
+
|
| 21 |
+
def next_item():
|
| 22 |
+
"""Increments the index, wrapping around to the start if it reaches the end."""
|
| 23 |
+
if list_length > 0:
|
| 24 |
+
st.session_state[session_state_key] = (st.session_state[session_state_key] + 1) % list_length
|
| 25 |
+
|
| 26 |
+
def prev_item():
|
| 27 |
+
"""Decrements the index, wrapping around to the end if it goes below zero."""
|
| 28 |
+
if list_length > 0:
|
| 29 |
+
st.session_state[session_state_key] = (st.session_state[session_state_key] - 1 + list_length) % list_length
|
| 30 |
+
|
| 31 |
+
# Create the navigation columns.
|
| 32 |
+
col1, col2, col3 = st.columns([2, 3, 2])
|
| 33 |
+
|
| 34 |
+
with col1:
|
| 35 |
+
st.button("⬅️ Previous", on_click=prev_item, use_container_width=True, disabled=(list_length <= 1))
|
| 36 |
+
|
| 37 |
+
# Display the current position and the popover for jumping to a specific item.
|
| 38 |
+
with col2:
|
| 39 |
+
if list_length > 0:
|
| 40 |
+
current_num = st.session_state[session_state_key] + 1
|
| 41 |
+
|
| 42 |
+
popover = st.popover(f"Item {current_num} of {list_length}", use_container_width=True)
|
| 43 |
+
|
| 44 |
+
with popover:
|
| 45 |
+
st.markdown("Jump to a specific item:")
|
| 46 |
+
target_index_input = st.number_input(
|
| 47 |
+
"Item Number",
|
| 48 |
+
min_value=1,
|
| 49 |
+
max_value=list_length,
|
| 50 |
+
value=current_num,
|
| 51 |
+
step=1,
|
| 52 |
+
label_visibility="collapsed"
|
| 53 |
+
)
|
| 54 |
+
if st.button("Go"):
|
| 55 |
+
st.session_state[session_state_key] = target_index_input - 1
|
| 56 |
+
st.rerun()
|
| 57 |
+
else:
|
| 58 |
+
st.text("No items to display")
|
| 59 |
+
|
| 60 |
+
# "Next" button in the third column.
|
| 61 |
+
with col3:
|
| 62 |
+
st.button("Next ➡️", on_click=next_item, use_container_width=True, disabled=(list_length <= 1))
|
| 63 |
+
|
| 64 |
+
# Return the currently selected item.
|
| 65 |
+
if list_length > 0:
|
| 66 |
+
return st.session_state[session_state_key]
|
| 67 |
+
else:
|
| 68 |
+
return None
|
src/gui_elements/stateful_widget.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
class StatefulWidgets:
|
| 4 |
+
"""
|
| 5 |
+
A class to create Streamlit widgets with encapsulated session state management.
|
| 6 |
+
"""
|
| 7 |
+
def __init__(self, prefix: str = "_"):
|
| 8 |
+
"""
|
| 9 |
+
Initializes the StatefulWidgets class.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
prefix (str): A prefix to use for the widget's internal session state key.
|
| 13 |
+
"""
|
| 14 |
+
self.prefix = prefix
|
| 15 |
+
|
| 16 |
+
def _store_value(self, key: str) -> None:
|
| 17 |
+
"""
|
| 18 |
+
Callback function to store the widget's value from the internal state
|
| 19 |
+
to the main session state.
|
| 20 |
+
"""
|
| 21 |
+
st.session_state[key] = st.session_state[f"{self.prefix}{key}"]
|
| 22 |
+
|
| 23 |
+
def _initialize_and_load(self, key: str, initial_value) -> None:
|
| 24 |
+
"""
|
| 25 |
+
Initializes the session state for a given key if it doesn't exist,
|
| 26 |
+
and then loads this value into the widget's internal state.
|
| 27 |
+
"""
|
| 28 |
+
if key not in st.session_state:
|
| 29 |
+
st.session_state[key] = initial_value
|
| 30 |
+
st.session_state[f"{self.prefix}{key}"] = st.session_state[key]
|
| 31 |
+
|
| 32 |
+
def create(self, widget_func, key: str, *args, initial_value=None, **kwargs):
|
| 33 |
+
"""
|
| 34 |
+
Creates a Streamlit widget and binds its state.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
widget_func: The Streamlit widget function (e.g., st.text_input).
|
| 38 |
+
key (str): The key for the main session state.
|
| 39 |
+
*args: Positional arguments to pass to the widget function.
|
| 40 |
+
initial_value: The initial value to set in the session state if not present.
|
| 41 |
+
**kwargs: Keyword arguments to pass to the widget function.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
The created Streamlit widget.
|
| 45 |
+
"""
|
| 46 |
+
self._initialize_and_load(key, initial_value)
|
| 47 |
+
return widget_func(
|
| 48 |
+
*args,
|
| 49 |
+
key=f"{self.prefix}{key}",
|
| 50 |
+
on_change=self._store_value,
|
| 51 |
+
args=(key,),
|
| 52 |
+
**kwargs,
|
| 53 |
+
)
|
src/pages/01_Prompt_Configuration.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from qstn.prompt_builder import LLMPrompt
|
| 3 |
+
from qstn.utilities.constants import QuestionnairePresentation
|
| 4 |
+
from qstn.utilities import placeholder
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from gui_elements.paginator import paginator
|
| 8 |
+
from gui_elements.stateful_widget import StatefulWidgets
|
| 9 |
+
|
| 10 |
+
# CONSTANTS FOR FIELDS
|
| 11 |
+
question_stem_field = "Question Stem"
|
| 12 |
+
randomize_order_tick = "Randomize the order of items"
|
| 13 |
+
system_prompt_field = "System prompt"
|
| 14 |
+
prompt_field = "Prompt"
|
| 15 |
+
change_all_system_prompts_checkbox = "system_change_all"
|
| 16 |
+
change_all_prompts_checkbox = "prompts_change_all"
|
| 17 |
+
|
| 18 |
+
field_ids = [question_stem_field, randomize_order_tick]
|
| 19 |
+
|
| 20 |
+
st.set_page_config(layout="wide")
|
| 21 |
+
st.title("Prompt Configuration")
|
| 22 |
+
st.write(
|
| 23 |
+
"This interface allows you configure how the questions are prompted to the LLM and the overall prompt structure. "
|
| 24 |
+
"These options are applied to every questionnaire in your survey."
|
| 25 |
+
)
|
| 26 |
+
st.page_link("pages/02_Option_Prompt.py", label="Click here to adjust the answer options.")
|
| 27 |
+
st.divider()
|
| 28 |
+
|
| 29 |
+
@st.cache_data
|
| 30 |
+
def create_stateful_widget() -> StatefulWidgets:
|
| 31 |
+
return StatefulWidgets()
|
| 32 |
+
|
| 33 |
+
state = create_stateful_widget()
|
| 34 |
+
|
| 35 |
+
if "questionnaires" not in st.session_state:
|
| 36 |
+
st.error("You need to first upload a questionnaire and the population you want to survey.")
|
| 37 |
+
st.stop()
|
| 38 |
+
disabled = True
|
| 39 |
+
else:
|
| 40 |
+
disabled = False
|
| 41 |
+
|
| 42 |
+
if 'current_index' not in st.session_state:
|
| 43 |
+
st.session_state.current_index = 0
|
| 44 |
+
|
| 45 |
+
#current_questionnaire_id = paginator(st.session_state.questionnaires, "current_questionnaire_index_prepare")
|
| 46 |
+
current_questionnaire_id = paginator(st.session_state.questionnaires, "current_questionnaire_index_prompt")
|
| 47 |
+
|
| 48 |
+
if not "temporary_questionnaire" in st.session_state:
|
| 49 |
+
st.session_state.temporary_questionnaire = st.session_state.questionnaires[0].duplicate()
|
| 50 |
+
|
| 51 |
+
if not "base_questionnaire" in st.session_state:
|
| 52 |
+
st.session_state.base_questionnaire = st.session_state.temporary_questionnaire.duplicate()
|
| 53 |
+
|
| 54 |
+
def process_inputs(input: Any, field_id: str) -> str:
|
| 55 |
+
if "survey_options" in st.session_state:
|
| 56 |
+
survey_options = st.session_state.survey_options
|
| 57 |
+
else:
|
| 58 |
+
survey_options = None
|
| 59 |
+
|
| 60 |
+
if field_id == question_stem_field:
|
| 61 |
+
LLMPrompt.prepare_prompt
|
| 62 |
+
st.session_state.temporary_questionnaire.prepare_prompt(
|
| 63 |
+
question_stem=input,
|
| 64 |
+
answer_options=survey_options,
|
| 65 |
+
randomized_item_order=randomize_order_bool,
|
| 66 |
+
)
|
| 67 |
+
st.session_state.base_questionnaire.prepare_prompt(
|
| 68 |
+
question_stem=input,
|
| 69 |
+
answer_options=survey_options,
|
| 70 |
+
randomized_item_order=randomize_order_bool,
|
| 71 |
+
)
|
| 72 |
+
elif field_id == randomize_order_tick:
|
| 73 |
+
if input == True:
|
| 74 |
+
st.session_state.temporary_questionnaire.prepare_prompt(
|
| 75 |
+
question_stem=question_stem_input,
|
| 76 |
+
answer_options=survey_options,
|
| 77 |
+
randomized_item_order=input,
|
| 78 |
+
)
|
| 79 |
+
else:
|
| 80 |
+
st.session_state.temporary_questionnaire = st.session_state.base_questionnaire.duplicate()
|
| 81 |
+
|
| 82 |
+
def handle_change(field_id: str):
|
| 83 |
+
"""
|
| 84 |
+
This single callback handles changes from any text field.
|
| 85 |
+
It reads the input from session state using the unique key,
|
| 86 |
+
processes it, and saves the output to session state.
|
| 87 |
+
"""
|
| 88 |
+
input_key = f"input_{field_id}"
|
| 89 |
+
|
| 90 |
+
with st.spinner(f"Processing {field_id}..."):
|
| 91 |
+
# time.sleep(0.5) # Simulate work
|
| 92 |
+
process_inputs(st.session_state[input_key], field_id)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if "questionnaires" in st.session_state and st.session_state.questionnaires is not None:
|
| 96 |
+
try:
|
| 97 |
+
questionnaire = st.session_state.questionnaires[current_questionnaire_id].duplicate()
|
| 98 |
+
except IndexError:
|
| 99 |
+
st.error("Index is out of range. Resetting to the first item.")
|
| 100 |
+
current_questionnaire_id = 0
|
| 101 |
+
questionnaire = st.session_state.questionnaires[current_questionnaire_id].duplicate()
|
| 102 |
+
|
| 103 |
+
col1, col2 = st.columns(2, gap="large")
|
| 104 |
+
|
| 105 |
+
with col1:
|
| 106 |
+
st.subheader("⚙️ Configuration")
|
| 107 |
+
|
| 108 |
+
for field_id in field_ids:
|
| 109 |
+
input_key = f"input_{field_id}"
|
| 110 |
+
if not input_key in st.session_state:
|
| 111 |
+
if field_id == question_stem_field:
|
| 112 |
+
st.session_state[input_key] = st.session_state.temporary_questionnaire ._questions[0].question_stem
|
| 113 |
+
if field_id == randomize_order_tick:
|
| 114 |
+
st.session_state[input_key] = False
|
| 115 |
+
|
| 116 |
+
# Handle placeholder replacement before widget is created
|
| 117 |
+
input_key = f"input_{question_stem_field}"
|
| 118 |
+
if "placeholder_to_replace" in st.session_state and st.session_state.placeholder_to_replace:
|
| 119 |
+
current_text = st.session_state.get(input_key, "")
|
| 120 |
+
placeholder_shortcut = st.session_state.placeholder_to_replace["shortcut"]
|
| 121 |
+
placeholder_value = st.session_state.placeholder_to_replace["value"]
|
| 122 |
+
|
| 123 |
+
# Replace all occurrences of the shortcut (e.g., -Q) with the placeholder
|
| 124 |
+
if placeholder_shortcut in current_text:
|
| 125 |
+
st.session_state[input_key] = current_text.replace(placeholder_shortcut, placeholder_value)
|
| 126 |
+
else:
|
| 127 |
+
# Shortcut not found, append at the end
|
| 128 |
+
st.session_state[input_key] = current_text + f" {placeholder_value} "
|
| 129 |
+
|
| 130 |
+
st.session_state.placeholder_to_replace = None
|
| 131 |
+
st.rerun()
|
| 132 |
+
|
| 133 |
+
# --- Input Widgets (No Form) ---
|
| 134 |
+
question_stem_input = st.text_area(
|
| 135 |
+
question_stem_field,
|
| 136 |
+
key=f"input_{question_stem_field}",
|
| 137 |
+
# placeholder="e.g., How would you rate the following aspects of our service?",
|
| 138 |
+
#on_change=handle_change,
|
| 139 |
+
kwargs={'field_id': question_stem_field},
|
| 140 |
+
height=100,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# --- Placeholder Replacement Buttons ---
|
| 144 |
+
st.write("**Insert Placeholder:**")
|
| 145 |
+
|
| 146 |
+
# Define available placeholders with their shortcuts and character labels
|
| 147 |
+
# Format: (placeholder_value, shortcut, character_label, description)
|
| 148 |
+
available_placeholders = [
|
| 149 |
+
(placeholder.QUESTION_CONTENT, "-Q", "Q", "Question Content"),
|
| 150 |
+
(placeholder.PROMPT_OPTIONS, "-O", "O", "Prompt Options"),
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
# Create shortcuts list for the tip
|
| 154 |
+
shortcuts_list = ", ".join([f"`{shortcut}`" for _, shortcut, _, _ in available_placeholders])
|
| 155 |
+
st.caption(f"💡 Tip: Type shortcuts {shortcuts_list} in the text, then click the button to replace them with placeholders.")
|
| 156 |
+
|
| 157 |
+
# Create buttons in columns with consistent formatting
|
| 158 |
+
cols = st.columns(len(available_placeholders))
|
| 159 |
+
for i, (placeholder_value, shortcut, char_label, description) in enumerate(available_placeholders):
|
| 160 |
+
button_label = description # Use the actual placeholder name
|
| 161 |
+
button_key = f"btn_placeholder_{char_label}"
|
| 162 |
+
|
| 163 |
+
if cols[i].button(button_label, key=button_key, use_container_width=True, help=f"Replaces '{shortcut}' with {placeholder_value}"):
|
| 164 |
+
st.session_state.placeholder_to_replace = {
|
| 165 |
+
"shortcut": shortcut,
|
| 166 |
+
"value": placeholder_value
|
| 167 |
+
}
|
| 168 |
+
st.rerun()
|
| 169 |
+
|
| 170 |
+
randomize_order_bool = st.checkbox(
|
| 171 |
+
randomize_order_tick,
|
| 172 |
+
key=f"input_{randomize_order_tick}",
|
| 173 |
+
value=False,
|
| 174 |
+
#on_change=handle_change,
|
| 175 |
+
kwargs={'field_id': randomize_order_tick}
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
st.divider()
|
| 179 |
+
|
| 180 |
+
# System prompt and main prompt section (from Basic Prompt Settings)
|
| 181 |
+
new_system_prompt = st.text_area(
|
| 182 |
+
label=system_prompt_field,
|
| 183 |
+
key=f"{system_prompt_field}{current_questionnaire_id}",
|
| 184 |
+
value=questionnaire.system_prompt,
|
| 185 |
+
help="The system prompt the model is prompted with."
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
change_all_system = state.create(
|
| 189 |
+
st.checkbox,
|
| 190 |
+
key=change_all_system_prompts_checkbox,
|
| 191 |
+
label="On update: change all System Prompts",
|
| 192 |
+
help="If this is ticked, all system prompts will be changed to this.",
|
| 193 |
+
initial_value=False
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Handle placeholder replacement for main prompt before widget is created
|
| 197 |
+
prompt_key = f"{prompt_field}{current_questionnaire_id}"
|
| 198 |
+
if "main_prompt_placeholder_to_replace" in st.session_state and st.session_state.main_prompt_placeholder_to_replace:
|
| 199 |
+
current_prompt_text = st.session_state.get(prompt_key, questionnaire.prompt)
|
| 200 |
+
placeholder_shortcut = st.session_state.main_prompt_placeholder_to_replace["shortcut"]
|
| 201 |
+
placeholder_value = st.session_state.main_prompt_placeholder_to_replace["value"]
|
| 202 |
+
|
| 203 |
+
if placeholder_shortcut in current_prompt_text:
|
| 204 |
+
st.session_state[prompt_key] = current_prompt_text.replace(placeholder_shortcut, placeholder_value)
|
| 205 |
+
else:
|
| 206 |
+
st.session_state[prompt_key] = current_prompt_text + f" {placeholder_value} "
|
| 207 |
+
|
| 208 |
+
st.session_state.main_prompt_placeholder_to_replace = None
|
| 209 |
+
st.rerun()
|
| 210 |
+
|
| 211 |
+
new_prompt = st.text_area(
|
| 212 |
+
label=prompt_field,
|
| 213 |
+
key=prompt_key,
|
| 214 |
+
value=questionnaire.prompt,
|
| 215 |
+
help="Instructions that are given to the model before the questions."
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Placeholder insertion buttons for main prompt
|
| 219 |
+
st.write("**Insert Placeholder in Main Prompt:**")
|
| 220 |
+
main_prompt_placeholders = [
|
| 221 |
+
(placeholder.PROMPT_QUESTIONS, "-P", "P", "Prompt Questions"),
|
| 222 |
+
(placeholder.PROMPT_OPTIONS, "-O", "O", "Prompt Options"),
|
| 223 |
+
(placeholder.PROMPT_AUTOMATIC_OUTPUT_INSTRUCTIONS, "-A", "A", "Automatic Output"),
|
| 224 |
+
(placeholder.JSON_TEMPLATE, "-J", "J", "JSON Template"),
|
| 225 |
+
]
|
| 226 |
+
|
| 227 |
+
main_shortcuts_list = ", ".join([f"`{shortcut}`" for _, shortcut, _, _ in main_prompt_placeholders])
|
| 228 |
+
st.caption(f"💡 Tip: Type shortcuts {main_shortcuts_list} in the main prompt, then click the button to replace them.")
|
| 229 |
+
|
| 230 |
+
cols_main = st.columns(len(main_prompt_placeholders))
|
| 231 |
+
for i, (placeholder_value, shortcut, char_label, description) in enumerate(main_prompt_placeholders):
|
| 232 |
+
button_key = f"btn_main_placeholder_{char_label}"
|
| 233 |
+
if cols_main[i].button(description, key=button_key, use_container_width=True, help=f"Replaces '{shortcut}' with {placeholder_value}"):
|
| 234 |
+
st.session_state.main_prompt_placeholder_to_replace = {
|
| 235 |
+
"shortcut": shortcut,
|
| 236 |
+
"value": placeholder_value
|
| 237 |
+
}
|
| 238 |
+
st.rerun()
|
| 239 |
+
|
| 240 |
+
change_all_questionnaire = state.create(
|
| 241 |
+
st.checkbox,
|
| 242 |
+
key=change_all_prompts_checkbox,
|
| 243 |
+
label="On update: change all questionnaire instructions",
|
| 244 |
+
help="If this is ticked, all questionnaire instructions will be changed to this.",
|
| 245 |
+
initial_value=False
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Place the corresponding output in the second column
|
| 249 |
+
with col2:
|
| 250 |
+
st.subheader("📄 Live Preview")
|
| 251 |
+
|
| 252 |
+
# --- The Dynamic Preview Logic ---
|
| 253 |
+
# This block re-runs on every widget interaction.
|
| 254 |
+
with st.container(border=True):
|
| 255 |
+
# Update temporary questionnaire with question stem
|
| 256 |
+
if "survey_options" in st.session_state:
|
| 257 |
+
survey_options = st.session_state.survey_options
|
| 258 |
+
else:
|
| 259 |
+
survey_options = None
|
| 260 |
+
|
| 261 |
+
if randomize_order_bool:
|
| 262 |
+
st.session_state.temporary_questionnaire.prepare_prompt(
|
| 263 |
+
question_stem=question_stem_input,
|
| 264 |
+
answer_options=survey_options,
|
| 265 |
+
randomized_item_order=randomize_order_bool,
|
| 266 |
+
)
|
| 267 |
+
st.session_state.base_questionnaire.prepare_prompt(
|
| 268 |
+
question_stem=question_stem_input,
|
| 269 |
+
answer_options=survey_options,
|
| 270 |
+
randomized_item_order=False,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
if not randomize_order_bool:
|
| 274 |
+
st.session_state.temporary_questionnaire = st.session_state.base_questionnaire.duplicate()
|
| 275 |
+
|
| 276 |
+
# Update system prompt and main prompt for preview (apply to temporary_questionnaire)
|
| 277 |
+
st.session_state.temporary_questionnaire.system_prompt = new_system_prompt
|
| 278 |
+
st.session_state.temporary_questionnaire.prompt = new_prompt
|
| 279 |
+
current_system_prompt, current_prompt = st.session_state.temporary_questionnaire.get_prompt_for_questionnaire_type(QuestionnairePresentation.SEQUENTIAL)
|
| 280 |
+
current_system_prompt = current_system_prompt.replace("\n", " \n")
|
| 281 |
+
current_prompt = current_prompt.replace("\n", " \n")
|
| 282 |
+
st.write(current_system_prompt)
|
| 283 |
+
st.write(current_prompt)
|
| 284 |
+
|
| 285 |
+
st.divider()
|
| 286 |
+
|
| 287 |
+
if st.button("Update Prompt(s)", type="secondary", use_container_width=True):
|
| 288 |
+
if change_all_system:
|
| 289 |
+
for questionnaire in st.session_state.questionnaires:
|
| 290 |
+
questionnaire.system_prompt = new_system_prompt
|
| 291 |
+
else:
|
| 292 |
+
st.session_state.questionnaires[current_questionnaire_id].system_prompt = new_system_prompt
|
| 293 |
+
|
| 294 |
+
if change_all_questionnaire:
|
| 295 |
+
for questionnaire in st.session_state.questionnaires:
|
| 296 |
+
questionnaire.prompt = new_prompt
|
| 297 |
+
else:
|
| 298 |
+
st.session_state.questionnaires[current_questionnaire_id].prompt = new_prompt
|
| 299 |
+
st.success("Prompt(s) updated!")
|
| 300 |
+
|
| 301 |
+
if st.button("Confirm and Prepare Questionnaire", type="primary", use_container_width=True):
|
| 302 |
+
for questionnaire in st.session_state.questionnaires:
|
| 303 |
+
questionnaire.prepare_prompt(
|
| 304 |
+
question_stem=question_stem_input,
|
| 305 |
+
answer_options=survey_options,
|
| 306 |
+
randomized_item_order=randomize_order_bool,
|
| 307 |
+
)
|
| 308 |
+
st.success("Changed the prompts!")
|
| 309 |
+
st.switch_page("pages/02_Option_Prompt.py")
|
| 310 |
+
else:
|
| 311 |
+
st.warning("No data found. Please upload a CSV file on the 'Start Page' first.")
|
src/pages/02_Option_Prompt.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from qstn.prompt_builder import LLMPrompt, generate_likert_options
|
| 3 |
+
from qstn.utilities.prompt_templates import (
|
| 4 |
+
LIST_OPTIONS_DEFAULT,
|
| 5 |
+
SCALE_OPTIONS_DEFAULT,
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
from gui_elements.stateful_widget import StatefulWidgets
|
| 9 |
+
|
| 10 |
+
st.set_page_config(layout="wide")
|
| 11 |
+
st.title("Likert Scale Options Generator")
|
| 12 |
+
st.write(
|
| 13 |
+
"This interface allows you to configure and generate Likert scale answer options by adjusting the parameters below."
|
| 14 |
+
)
|
| 15 |
+
st.divider()
|
| 16 |
+
|
| 17 |
+
if "questionnaires" not in st.session_state:
|
| 18 |
+
st.error("You need to first upload a questionnaire and the population you want to survey.")
|
| 19 |
+
st.stop()
|
| 20 |
+
disabled = True
|
| 21 |
+
else:
|
| 22 |
+
disabled = False
|
| 23 |
+
|
| 24 |
+
#if 'answer_texts_input' not in st.session_state:
|
| 25 |
+
#st.session_state.answer_texts_input = "Strongly Disagree\nDisagree\nNeutral\nAgree\nStrongly Agree"
|
| 26 |
+
|
| 27 |
+
state = StatefulWidgets()
|
| 28 |
+
|
| 29 |
+
# Use a form to batch all inputs together
|
| 30 |
+
with st.container(border=True):
|
| 31 |
+
# --- Main Configuration ---
|
| 32 |
+
st.subheader("Main Configuration")
|
| 33 |
+
col1, col2, col3 = st.columns(3)
|
| 34 |
+
|
| 35 |
+
with col1:
|
| 36 |
+
n_options = state.create(
|
| 37 |
+
st.number_input,
|
| 38 |
+
"n_options",
|
| 39 |
+
"Number of Options (n)",
|
| 40 |
+
initial_value=5,
|
| 41 |
+
min_value=2,
|
| 42 |
+
step=1,
|
| 43 |
+
help="The total number of choices in the scale.",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
with col2:
|
| 47 |
+
idx_type = state.create(
|
| 48 |
+
st.selectbox,
|
| 49 |
+
"idx_type",
|
| 50 |
+
"Index Type",
|
| 51 |
+
initial_value="integer",
|
| 52 |
+
options=["integer", "char_low", "char_up"],
|
| 53 |
+
help="The type of index to use for the options.",
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
with col3:
|
| 57 |
+
start_idx = state.create(
|
| 58 |
+
st.number_input,
|
| 59 |
+
"start_idx",
|
| 60 |
+
"Starting Index",
|
| 61 |
+
initial_value=1,
|
| 62 |
+
step=1,
|
| 63 |
+
help="The number to start counting from (e.g., 1).",
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# --- Order and Structure Options ---
|
| 67 |
+
st.subheader("Ordering and Structure")
|
| 68 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 69 |
+
|
| 70 |
+
with col1:
|
| 71 |
+
only_from_to_scale = state.create(
|
| 72 |
+
st.checkbox,
|
| 73 |
+
"only_from_to_scale",
|
| 74 |
+
"From-To Scale Only",
|
| 75 |
+
initial_value=False,
|
| 76 |
+
help="If checked, only the first and last answer labels are display e.g. 1 Strongly Disagree to 5 Strongly agree.",
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
with col2:
|
| 80 |
+
random_order = state.create(
|
| 81 |
+
st.checkbox,
|
| 82 |
+
"random_order",
|
| 83 |
+
"Random Order",
|
| 84 |
+
initial_value=False,
|
| 85 |
+
help="Randomize the order of options.",
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
with col3:
|
| 89 |
+
reversed_order = state.create(
|
| 90 |
+
st.checkbox,
|
| 91 |
+
"reversed_order",
|
| 92 |
+
"Reversed Order",
|
| 93 |
+
initial_value=False,
|
| 94 |
+
help="Reverse the order of options.",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
with col4:
|
| 98 |
+
even_order = state.create(
|
| 99 |
+
st.checkbox,
|
| 100 |
+
"even_order",
|
| 101 |
+
"Even Order",
|
| 102 |
+
initial_value=False,
|
| 103 |
+
help="If there is an uneven number of answer texts, the middle section is automatically removed.",
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# --- Answer Texts Input ---
|
| 107 |
+
st.subheader("Answer Texts")
|
| 108 |
+
|
| 109 |
+
answer_texts = state.create(
|
| 110 |
+
st.text_area,
|
| 111 |
+
"answer_texts",
|
| 112 |
+
"Enter Answer Texts (one per line)",
|
| 113 |
+
initial_value="Strongly Disagree\nDisagree\nNeutral\nAgree\nStrongly Agree",
|
| 114 |
+
height=150,
|
| 115 |
+
help="Enter the labels for each answer option.",
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# --- Advanced Configuration ---
|
| 119 |
+
with st.expander("Advanced Configuration"):
|
| 120 |
+
options_separator = state.create(
|
| 121 |
+
st.text_input,
|
| 122 |
+
"options_separator",
|
| 123 |
+
"Options Separator",
|
| 124 |
+
initial_value=", ",
|
| 125 |
+
help="The character(s) used to separate options in the final string.",
|
| 126 |
+
)
|
| 127 |
+
list_prompt_template = state.create(
|
| 128 |
+
st.text_area,
|
| 129 |
+
"list_prompt_template",
|
| 130 |
+
"List Prompt Template",
|
| 131 |
+
initial_value=LIST_OPTIONS_DEFAULT,
|
| 132 |
+
height=100,
|
| 133 |
+
help="Write how the options should be presented to the model.",
|
| 134 |
+
)
|
| 135 |
+
scale_prompt_template = state.create(
|
| 136 |
+
st.text_area,
|
| 137 |
+
"scale_prompt_template",
|
| 138 |
+
"Scale Prompt Template",
|
| 139 |
+
initial_value=SCALE_OPTIONS_DEFAULT,
|
| 140 |
+
height=100,
|
| 141 |
+
help="Write how the options should be presented to the model.",
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# The submit button for the form
|
| 146 |
+
submitted = st.button("Confirm and Generate Options", disabled=disabled, type="primary", use_container_width=True)
|
| 147 |
+
|
| 148 |
+
if st.button("Remove all options", use_container_width=True, icon="❌"):
|
| 149 |
+
st.session_state.survey_options = None
|
| 150 |
+
st.switch_page("pages/03_Inference_Setting.py")
|
| 151 |
+
|
| 152 |
+
# --- Processing and Output ---
|
| 153 |
+
if submitted:
|
| 154 |
+
#print("Session state answer texts "+ st.session_state.answer_texts_input)
|
| 155 |
+
#print(answer_texts_input)
|
| 156 |
+
# Convert the raw text area string into a list of strings.
|
| 157 |
+
answer_texts_list = [
|
| 158 |
+
text.strip() for text in answer_texts.split("\n") if text.strip()
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
# --- Input Validation ---
|
| 162 |
+
validation_ok = True
|
| 163 |
+
if only_from_to_scale and len(answer_texts_list) != 2:
|
| 164 |
+
st.error(
|
| 165 |
+
f"Error: When 'From-To Scale Only' is selected, you must provide exactly 2 answer texts. You provided {len(answer_texts_list)}."
|
| 166 |
+
)
|
| 167 |
+
validation_ok = False
|
| 168 |
+
|
| 169 |
+
if not only_from_to_scale and len(answer_texts_list) != n_options:
|
| 170 |
+
st.error(
|
| 171 |
+
f"Error: The number of answer texts ({len(answer_texts_list)}) must match the 'Number of Options' ({n_options})."
|
| 172 |
+
)
|
| 173 |
+
validation_ok = False
|
| 174 |
+
|
| 175 |
+
if reversed_order and random_order:
|
| 176 |
+
st.error(f"Error: Reversed Order and Random Order cannot both be true.")
|
| 177 |
+
validation_ok = False
|
| 178 |
+
|
| 179 |
+
if validation_ok:
|
| 180 |
+
survey_options = generate_likert_options(
|
| 181 |
+
n=n_options,
|
| 182 |
+
answer_texts=answer_texts_list,
|
| 183 |
+
only_from_to_scale=only_from_to_scale,
|
| 184 |
+
random_order=random_order,
|
| 185 |
+
reversed_order=reversed_order,
|
| 186 |
+
even_order=even_order,
|
| 187 |
+
start_idx=start_idx,
|
| 188 |
+
list_prompt_template=list_prompt_template,
|
| 189 |
+
scale_prompt_template=scale_prompt_template,
|
| 190 |
+
options_separator=options_separator,
|
| 191 |
+
idx_type=idx_type,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
st.session_state.survey_options = survey_options
|
| 195 |
+
st.switch_page("pages/03_Inference_Setting.py")
|
| 196 |
+
|
| 197 |
+
|
src/pages/03_Inference_Setting.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import json
|
| 3 |
+
from gui_elements.stateful_widget import StatefulWidgets
|
| 4 |
+
|
| 5 |
+
# --- Page Configuration ---
|
| 6 |
+
st.set_page_config(
|
| 7 |
+
page_title="Inference Settings",
|
| 8 |
+
layout="wide"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
st.title("AsyncOpenAI API Client & Inference Configurator")
|
| 12 |
+
st.markdown("Use the widgets below to configure the `AsyncOpenAI` client and the inference parameters for an API call. Advanced or less common options can be added as a JSON object.")
|
| 13 |
+
|
| 14 |
+
st.divider()
|
| 15 |
+
|
| 16 |
+
# --- Column Layout ---
|
| 17 |
+
col1, col2 = st.columns(2)
|
| 18 |
+
|
| 19 |
+
defaults = {
|
| 20 |
+
# Client Config
|
| 21 |
+
"api_key": "", "organization": "", "project": "", "base_url": "",
|
| 22 |
+
"timeout": 20, "max_retries": 2,
|
| 23 |
+
"advanced_client_params_str": '',
|
| 24 |
+
# Inference Config
|
| 25 |
+
"model_name": "", "temperature": 1.0, "max_tokens": 1024,
|
| 26 |
+
"top_p": 1.0, "seed": 42,
|
| 27 |
+
"advanced_inference_params_str": ''
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
for key, value in defaults.items():
|
| 31 |
+
if key not in st.session_state:
|
| 32 |
+
st.session_state[key] = value
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
state = StatefulWidgets()
|
| 36 |
+
|
| 37 |
+
# ==============================================================================
|
| 38 |
+
# COLUMN 1: OPENAI CLIENT CONFIGURATION
|
| 39 |
+
# ==============================================================================
|
| 40 |
+
with col1:
|
| 41 |
+
st.header("1. Client Configuration")
|
| 42 |
+
|
| 43 |
+
with col2:
|
| 44 |
+
st.header("2. Inference Configuration")
|
| 45 |
+
|
| 46 |
+
with col1:
|
| 47 |
+
with st.container(border=True):
|
| 48 |
+
st.subheader("Core Settings")
|
| 49 |
+
|
| 50 |
+
api_key = state.create(
|
| 51 |
+
st.text_input,
|
| 52 |
+
"api_key",
|
| 53 |
+
"API Key",
|
| 54 |
+
initial_value="",
|
| 55 |
+
type="password",
|
| 56 |
+
placeholder="sk-...",
|
| 57 |
+
help="Your OpenAI API key. It is handled securely by Streamlit."
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
organization = state.create(
|
| 61 |
+
st.text_input,
|
| 62 |
+
"organization",
|
| 63 |
+
"Organization ID",
|
| 64 |
+
initial_value="",
|
| 65 |
+
placeholder="org-...",
|
| 66 |
+
help="Optional identifier for your organization."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
project = state.create(
|
| 70 |
+
st.text_input,
|
| 71 |
+
"project",
|
| 72 |
+
"Project ID",
|
| 73 |
+
initial_value="",
|
| 74 |
+
placeholder="proj_...",
|
| 75 |
+
help="Optional identifier for your project."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
base_url = state.create(
|
| 79 |
+
st.text_input,
|
| 80 |
+
"base_url",
|
| 81 |
+
"Base URL",
|
| 82 |
+
initial_value="",
|
| 83 |
+
placeholder="https://api.openai.com/v1",
|
| 84 |
+
help="The base URL for the API. Leave empty for the default."
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
timeout = state.create(
|
| 88 |
+
st.number_input,
|
| 89 |
+
"timeout",
|
| 90 |
+
"Timeout (seconds)",
|
| 91 |
+
initial_value=20,
|
| 92 |
+
min_value=1,
|
| 93 |
+
help="The timeout for API requests in seconds."
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
max_retries = state.create(
|
| 97 |
+
st.number_input,
|
| 98 |
+
"max_retries",
|
| 99 |
+
"Max Retries",
|
| 100 |
+
initial_value=2,
|
| 101 |
+
min_value=0,
|
| 102 |
+
help="The maximum number of times to retry a failed request."
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
with st.expander("Advanced Client Settings (JSON)"):
|
| 106 |
+
advanced_client_params_str = state.create(
|
| 107 |
+
st.text_area,
|
| 108 |
+
"advanced_client_params_str",
|
| 109 |
+
"JSON for other client parameters",
|
| 110 |
+
initial_value="",
|
| 111 |
+
placeholder='{\n "default_headers": {"X-Custom-Header": "value"}\n}',
|
| 112 |
+
height=150,
|
| 113 |
+
help='Enter any other client init parameters like "default_headers" or "default_query" as a valid JSON object.'
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# ==============================================================================
|
| 117 |
+
# COLUMN 2: INFERENCE PARAMETERS
|
| 118 |
+
# ==============================================================================
|
| 119 |
+
|
| 120 |
+
with col2:
|
| 121 |
+
with st.container(border=True):
|
| 122 |
+
st.subheader("Core Settings")
|
| 123 |
+
model_name = state.create(
|
| 124 |
+
st.text_input,
|
| 125 |
+
"model_name",
|
| 126 |
+
"Model Name",
|
| 127 |
+
#initial_value="meta-llama/Llama-3.1-70B-Instruct",
|
| 128 |
+
placeholder="meta-llama/Llama-3.1-70B-Instruct",
|
| 129 |
+
help="The model to use for the inference call."
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
temperature = state.create(
|
| 133 |
+
st.slider,
|
| 134 |
+
"temperature",
|
| 135 |
+
"Temperature",
|
| 136 |
+
min_value=0.0,
|
| 137 |
+
max_value=2.0,
|
| 138 |
+
step=0.01,
|
| 139 |
+
initial_value=1.0,
|
| 140 |
+
help="Controls randomness. Lower values are more deterministic and less creative."
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
max_tokens = state.create(
|
| 144 |
+
st.number_input,
|
| 145 |
+
"max_tokens",
|
| 146 |
+
"Max Tokens",
|
| 147 |
+
initial_value=1024,
|
| 148 |
+
min_value=1,
|
| 149 |
+
help="The maximum number of tokens to generate in the completion."
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
top_p = state.create(
|
| 153 |
+
st.slider,
|
| 154 |
+
"top_p",
|
| 155 |
+
"Top P",
|
| 156 |
+
min_value=0.0,
|
| 157 |
+
max_value=1.0,
|
| 158 |
+
step=0.01,
|
| 159 |
+
initial_value=1.0,
|
| 160 |
+
help="Controls nucleus sampling. The model considers tokens with top_p probability mass."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
seed = state.create(
|
| 164 |
+
st.number_input,
|
| 165 |
+
"seed",
|
| 166 |
+
"Seed",
|
| 167 |
+
initial_value=42,
|
| 168 |
+
min_value=0,
|
| 169 |
+
help="A specific seed for reproducibility of results."
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
with st.expander("Advanced Inference Settings (JSON)"):
|
| 173 |
+
advanced_inference_params_str = state.create(
|
| 174 |
+
st.text_area,
|
| 175 |
+
"advanced_inference_params_str",
|
| 176 |
+
"JSON for other inference parameters",
|
| 177 |
+
initial_value="",
|
| 178 |
+
placeholder='{\n "stop": ["\\n", " Human:"],\n "presence_penalty": 0\n}',
|
| 179 |
+
height=150,
|
| 180 |
+
help='Enter any other valid inference parameters like "stop", "logit_bias", or "frequency_penalty" as a JSON object.'
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# ==============================================================================
|
| 185 |
+
# GENERATION AND DISPLAY LOGIC
|
| 186 |
+
# ==============================================================================
|
| 187 |
+
st.divider()
|
| 188 |
+
|
| 189 |
+
if st.button("Generate Configuration & Code", type="primary", use_container_width=True):
|
| 190 |
+
# --- Process Client Config ---
|
| 191 |
+
client_config = {
|
| 192 |
+
"api_key": api_key
|
| 193 |
+
}
|
| 194 |
+
# Add optional string parameters if they are not empty
|
| 195 |
+
if organization: client_config["organization"] = organization
|
| 196 |
+
if project: client_config["project"] = project
|
| 197 |
+
if base_url: client_config["base_url"] = base_url
|
| 198 |
+
# Add numeric parameters
|
| 199 |
+
client_config["timeout"] = timeout
|
| 200 |
+
client_config["max_retries"] = max_retries
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
if advanced_client_params_str:
|
| 204 |
+
advanced_client_params = json.loads(advanced_client_params_str)
|
| 205 |
+
client_config.update(advanced_client_params)
|
| 206 |
+
except json.JSONDecodeError:
|
| 207 |
+
st.error("Invalid JSON detected in Advanced Client Settings. Please correct it.")
|
| 208 |
+
st.stop()
|
| 209 |
+
|
| 210 |
+
# --- Process Inference Config ---
|
| 211 |
+
inference_config = {
|
| 212 |
+
"model": model_name,
|
| 213 |
+
"temperature": temperature,
|
| 214 |
+
"max_tokens": max_tokens,
|
| 215 |
+
"top_p": top_p,
|
| 216 |
+
"seed": seed
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
if advanced_inference_params_str:
|
| 221 |
+
advanced_inference_params = json.loads(advanced_inference_params_str)
|
| 222 |
+
inference_config.update(advanced_inference_params)
|
| 223 |
+
except json.JSONDecodeError:
|
| 224 |
+
st.error("Invalid JSON detected in Advanced Inference Settings. Please correct it.")
|
| 225 |
+
st.stop()
|
| 226 |
+
|
| 227 |
+
st.session_state.client_config = client_config
|
| 228 |
+
st.session_state.inference_config = inference_config
|
| 229 |
+
|
| 230 |
+
st.success("Configuration generated successfully!")
|
src/pages/04_Final_Overview.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from gui_elements.paginator import paginator
|
| 3 |
+
from gui_elements.stateful_widget import StatefulWidgets
|
| 4 |
+
from gui_elements.output_manager import st_capture, TqdmToStreamlit
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
import queue
|
| 8 |
+
import time
|
| 9 |
+
import threading
|
| 10 |
+
import asyncio
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
from contextlib import redirect_stderr, redirect_stdout
|
| 15 |
+
|
| 16 |
+
from qstn.parser.llm_answer_parser import raw_responses
|
| 17 |
+
from qstn.utilities.constants import QuestionnairePresentation
|
| 18 |
+
from qstn.utilities.utils import create_one_dataframe
|
| 19 |
+
from qstn.survey_manager import (
|
| 20 |
+
conduct_survey_sequential,
|
| 21 |
+
conduct_survey_battery,
|
| 22 |
+
conduct_survey_single_item,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from streamlit.runtime.scriptrunner import add_script_run_ctx
|
| 26 |
+
|
| 27 |
+
from openai import AsyncOpenAI
|
| 28 |
+
|
| 29 |
+
# Set OpenAI's API key and API base to use vLLM's API server.
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if "questionnaires" not in st.session_state:
|
| 33 |
+
st.error(
|
| 34 |
+
"You need to first upload a questionnaire and the population you want to survey."
|
| 35 |
+
)
|
| 36 |
+
st.stop()
|
| 37 |
+
disabled = True
|
| 38 |
+
else:
|
| 39 |
+
disabled = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@st.cache_data
|
| 43 |
+
def create_stateful_widget() -> StatefulWidgets:
|
| 44 |
+
return StatefulWidgets()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
state = create_stateful_widget()
|
| 48 |
+
|
| 49 |
+
current_index = paginator(st.session_state.questionnaires, "overview_page")
|
| 50 |
+
|
| 51 |
+
questionnaires = st.session_state.questionnaires[current_index]
|
| 52 |
+
|
| 53 |
+
col_llm, col_prompt_display = st.columns(2)
|
| 54 |
+
|
| 55 |
+
with col_llm:
|
| 56 |
+
st.subheader("⚙️ Inference Parameters")
|
| 57 |
+
|
| 58 |
+
with st.container(border=True):
|
| 59 |
+
st.subheader("Core Settings")
|
| 60 |
+
model_name = state.create(
|
| 61 |
+
st.text_input,
|
| 62 |
+
"model_name",
|
| 63 |
+
"Model Name",
|
| 64 |
+
# initial_value="meta-llama/Llama-3.1-70B-Instruct",
|
| 65 |
+
# placeholder="meta-llama/Llama-3.1-70B-Instruct",
|
| 66 |
+
disabled=True,
|
| 67 |
+
help="The model to use for the inference call.",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
temperature = state.create(
|
| 71 |
+
st.slider,
|
| 72 |
+
"temperature",
|
| 73 |
+
"Temperature",
|
| 74 |
+
min_value=0.0,
|
| 75 |
+
max_value=2.0,
|
| 76 |
+
step=0.01,
|
| 77 |
+
initial_value=1.0,
|
| 78 |
+
disabled=True,
|
| 79 |
+
help="Controls randomness. Lower values are more deterministic and less creative.",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
max_tokens = state.create(
|
| 83 |
+
st.number_input,
|
| 84 |
+
"max_tokens",
|
| 85 |
+
"Max Tokens",
|
| 86 |
+
initial_value=1024,
|
| 87 |
+
min_value=1,
|
| 88 |
+
disabled=True,
|
| 89 |
+
help="The maximum number of tokens to generate in the completion.",
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
top_p = state.create(
|
| 93 |
+
st.slider,
|
| 94 |
+
"top_p",
|
| 95 |
+
"Top P",
|
| 96 |
+
min_value=0.0,
|
| 97 |
+
max_value=1.0,
|
| 98 |
+
step=0.01,
|
| 99 |
+
initial_value=1.0,
|
| 100 |
+
disabled=True,
|
| 101 |
+
help="Controls nucleus sampling. The model considers tokens with top_p probability mass.",
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
seed = state.create(
|
| 105 |
+
st.number_input,
|
| 106 |
+
"seed",
|
| 107 |
+
"Seed",
|
| 108 |
+
initial_value=42,
|
| 109 |
+
min_value=0,
|
| 110 |
+
disabled=True,
|
| 111 |
+
help="A specific seed for reproducibility of results.",
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
with st.expander("Advanced Inference Settings (JSON)"):
|
| 115 |
+
advanced_inference_params_str = state.create(
|
| 116 |
+
st.text_area,
|
| 117 |
+
"advanced_inference_params_str",
|
| 118 |
+
"JSON for other inference parameters",
|
| 119 |
+
initial_value="",
|
| 120 |
+
# placeholder='{\n "stop": ["\\n", " Human:"],\n "presence_penalty": 0\n}',
|
| 121 |
+
height=150,
|
| 122 |
+
disabled=True,
|
| 123 |
+
help='Enter any other valid inference parameters like "stop", "logit_bias", or "frequency_penalty" as a JSON object.',
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
with col_prompt_display:
|
| 128 |
+
st.subheader("📄 Live Preview")
|
| 129 |
+
|
| 130 |
+
# Survey method selector
|
| 131 |
+
survey_method_options = {
|
| 132 |
+
"Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM),
|
| 133 |
+
"Battery": ("battery", QuestionnairePresentation.BATTERY),
|
| 134 |
+
"Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL),
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
survey_method_display = state.create(
|
| 138 |
+
st.selectbox,
|
| 139 |
+
"survey_method",
|
| 140 |
+
"Questionnaire Method",
|
| 141 |
+
options=list(survey_method_options.keys()),
|
| 142 |
+
initial_value="Single item",
|
| 143 |
+
help="Choose how to conduct the questionnaire: Single item (one at a time), Battery (all questions together), or Sequential (with conversation history)."
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# Get the method name and questionnaire type from selection
|
| 147 |
+
selected_method_name, selected_questionnaire_type = survey_method_options[survey_method_display]
|
| 148 |
+
|
| 149 |
+
with st.container(border=True):
|
| 150 |
+
current_system_prompt, current_prompt = questionnaires.get_prompt_for_questionnaire_type(selected_questionnaire_type)
|
| 151 |
+
current_system_prompt = current_system_prompt.replace("\n", " \n")
|
| 152 |
+
current_prompt = current_prompt.replace("\n", " \n")
|
| 153 |
+
st.write(current_system_prompt)
|
| 154 |
+
st.write(current_prompt)
|
| 155 |
+
|
| 156 |
+
model_name = state.create(
|
| 157 |
+
st.text_input,
|
| 158 |
+
"save_file",
|
| 159 |
+
"Save File",
|
| 160 |
+
# initial_value="meta-llama/Llama-3.1-70B-Instruct",
|
| 161 |
+
# placeholder="meta-llama/Llama-3.1-70B-Instruct",
|
| 162 |
+
help="The save file to write your results to. Should be a csv file.",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
if st.button("Confirm and Run Questionnaire", type="primary", use_container_width=True):
|
| 167 |
+
st.write("Starting inference...")
|
| 168 |
+
|
| 169 |
+
openai_api_key = "EMPTY"
|
| 170 |
+
openai_api_base = "http://localhost:8000/v1"
|
| 171 |
+
|
| 172 |
+
client = AsyncOpenAI(**st.session_state.client_config)
|
| 173 |
+
|
| 174 |
+
inference_config = st.session_state.inference_config.copy()
|
| 175 |
+
|
| 176 |
+
model_name = inference_config.pop("model")
|
| 177 |
+
|
| 178 |
+
progress_text = st.empty()
|
| 179 |
+
|
| 180 |
+
log_queue = queue.Queue()
|
| 181 |
+
result_queue = queue.Queue()
|
| 182 |
+
|
| 183 |
+
class QueueWriter:
|
| 184 |
+
def __init__(self, q):
|
| 185 |
+
self.q = q
|
| 186 |
+
|
| 187 |
+
def write(self, message):
|
| 188 |
+
if message.strip():
|
| 189 |
+
self.q.put(message)
|
| 190 |
+
|
| 191 |
+
def flush(self):
|
| 192 |
+
# This function is needed to match the file-like object interface
|
| 193 |
+
# but we don't need to do anything here.
|
| 194 |
+
pass
|
| 195 |
+
|
| 196 |
+
# Helper function for asyncronous runs
|
| 197 |
+
def run_async_in_thread(
|
| 198 |
+
result_q, client, questionnaires, model_name, survey_method_name, **inference_config
|
| 199 |
+
):
|
| 200 |
+
queue_writer = QueueWriter(log_queue)
|
| 201 |
+
|
| 202 |
+
# We need to redirect the output to a queue, as streamlit does not support multithreading
|
| 203 |
+
# API concurrency should be configurable in the GUI
|
| 204 |
+
try:
|
| 205 |
+
with redirect_stderr(queue_writer):
|
| 206 |
+
# Select the appropriate survey method based on user choice
|
| 207 |
+
if survey_method_name == "single_item":
|
| 208 |
+
survey_func = conduct_survey_single_item
|
| 209 |
+
elif survey_method_name == "battery":
|
| 210 |
+
survey_func = conduct_survey_battery
|
| 211 |
+
elif survey_method_name == "sequential":
|
| 212 |
+
survey_func = conduct_survey_sequential
|
| 213 |
+
else:
|
| 214 |
+
survey_func = conduct_survey_single_item # Default fallback
|
| 215 |
+
|
| 216 |
+
result = survey_func(
|
| 217 |
+
client,
|
| 218 |
+
llm_prompts=questionnaires,
|
| 219 |
+
client_model_name=model_name,
|
| 220 |
+
api_concurrency=100,
|
| 221 |
+
**inference_config,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
except Exception as e:
|
| 225 |
+
result = e
|
| 226 |
+
st.error(e)
|
| 227 |
+
finally:
|
| 228 |
+
result_q.put(result)
|
| 229 |
+
|
| 230 |
+
while not log_queue.empty():
|
| 231 |
+
log_queue.get()
|
| 232 |
+
while not result_queue.empty():
|
| 233 |
+
result_queue.get()
|
| 234 |
+
|
| 235 |
+
# Get the selected survey method
|
| 236 |
+
survey_method_display = st.session_state.get("survey_method", "Single item")
|
| 237 |
+
survey_method_options = {
|
| 238 |
+
"Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM),
|
| 239 |
+
"Battery": ("battery", QuestionnairePresentation.BATTERY),
|
| 240 |
+
"Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL),
|
| 241 |
+
}
|
| 242 |
+
selected_method_name, _ = survey_method_options.get(survey_method_display, ("single_item", QuestionnairePresentation.SINGLE_ITEM))
|
| 243 |
+
|
| 244 |
+
thread = threading.Thread(
|
| 245 |
+
target=run_async_in_thread,
|
| 246 |
+
args=(result_queue, client, st.session_state.questionnaires, model_name, selected_method_name),
|
| 247 |
+
kwargs=inference_config,
|
| 248 |
+
)
|
| 249 |
+
thread.start()
|
| 250 |
+
|
| 251 |
+
all_questions_placeholder = st.empty()
|
| 252 |
+
progress_placeholder = st.empty()
|
| 253 |
+
|
| 254 |
+
while thread.is_alive():
|
| 255 |
+
try:
|
| 256 |
+
# Here we can write directly to the UI, as it is the main thread
|
| 257 |
+
# TQDM uses carriage returns (\r) to animate in the console, we only show clear lines
|
| 258 |
+
log_message = log_queue.get_nowait()
|
| 259 |
+
# This is quite a hacky solution for now, we should adjust QSTN to make the messages clearly parsable.
|
| 260 |
+
if "[A" not in log_message and "Processing Prompts" not in log_message:
|
| 261 |
+
all_questions_placeholder.text(log_message.strip().replace("\r", ""))
|
| 262 |
+
|
| 263 |
+
elif "Processing Prompts" in log_message:
|
| 264 |
+
progress_placeholder.text(log_message.strip().replace("\r", ""))
|
| 265 |
+
|
| 266 |
+
except queue.Empty:
|
| 267 |
+
pass
|
| 268 |
+
time.sleep(0.1)
|
| 269 |
+
thread.join()
|
| 270 |
+
|
| 271 |
+
all_questions_placeholder.empty()
|
| 272 |
+
progress_placeholder.empty()
|
| 273 |
+
|
| 274 |
+
try:
|
| 275 |
+
final_output = result_queue.get_nowait()
|
| 276 |
+
except queue.Empty:
|
| 277 |
+
st.error("Could not retrieve result from the asynchronous task.")
|
| 278 |
+
|
| 279 |
+
st.success("Finished inferencing! Saving results...")
|
| 280 |
+
|
| 281 |
+
responses = raw_responses(final_output)
|
| 282 |
+
|
| 283 |
+
df = create_one_dataframe(responses)
|
| 284 |
+
|
| 285 |
+
st.dataframe(df)
|
| 286 |
+
|
| 287 |
+
df.to_csv(st.session_state.save_file, index=False)
|
| 288 |
+
|
| 289 |
+
st.success(f"File saved to {st.session_state.save_file}!")
|
src/pages/__init__.py
ADDED
|
File without changes
|
src/pages/old_pages/_01_Basic_Prompt_Settings.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from qstn.llm_questionnaire import LLMQuestionnaire
|
| 3 |
+
from qstn.survey_manager import conduct_survey_single_item, conduct_survey_sequential, conduct_survey_battery, SurveyOptionGenerator, SurveyCreator
|
| 4 |
+
from qstn.utilities.constants import QuestionnaireType
|
| 5 |
+
from gui_elements.paginator import paginator
|
| 6 |
+
from gui_elements.stateful_widget import StatefulWidgets
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
#CONSTANTS FOR FIELDS
|
| 11 |
+
system_prompt_field = "System prompt"
|
| 12 |
+
prompt_field = "Prompt"
|
| 13 |
+
change_all_system_prompts_checkbox = "system_change_all"
|
| 14 |
+
change_all_prompts_checkbox = "prompts_change_all"
|
| 15 |
+
|
| 16 |
+
st.set_page_config(layout="wide")
|
| 17 |
+
st.title("Generate Prompt")
|
| 18 |
+
st.write(
|
| 19 |
+
"This interface allows you to inspect and change the system prompt and primary instructions for the model."
|
| 20 |
+
)
|
| 21 |
+
st.divider()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@st.cache_data
|
| 25 |
+
def create_stateful_widget() -> StatefulWidgets:
|
| 26 |
+
return StatefulWidgets()
|
| 27 |
+
|
| 28 |
+
state = create_stateful_widget()
|
| 29 |
+
|
| 30 |
+
#FOR DEBUGGING
|
| 31 |
+
# if "questionnaires" not in st.session_state:
|
| 32 |
+
# st.session_state.questionnaires = [SurveyCreator().from_path(survey_path="/home/maxi/Documents/SurveyGen/surveys/ANES.csv", questionnaire_path="/home/maxi/Documents/SurveyGen/surveys/ANES_PERSONAS.csv")]
|
| 33 |
+
if "questionnaires" not in st.session_state:
|
| 34 |
+
st.error("You need to first upload a questionnaire and the population you want to survey.")
|
| 35 |
+
st.stop()
|
| 36 |
+
|
| 37 |
+
if 'current_index' not in st.session_state:
|
| 38 |
+
st.session_state.current_index = 0
|
| 39 |
+
|
| 40 |
+
text_field_ids = [system_prompt_field, prompt_field]
|
| 41 |
+
|
| 42 |
+
current_questionnaire_id = paginator(st.session_state.questionnaires, "current_questionnaire_index_prompt")
|
| 43 |
+
|
| 44 |
+
st.divider()
|
| 45 |
+
|
| 46 |
+
if "questionnaires" in st.session_state and st.session_state.questionnaires is not None:
|
| 47 |
+
try:
|
| 48 |
+
questionnaire = st.session_state.questionnaires[current_questionnaire_id].duplicate()
|
| 49 |
+
except IndexError:
|
| 50 |
+
st.error("Index is out of range. Resetting to the first item.")
|
| 51 |
+
current_questionnaire_id = 0
|
| 52 |
+
questionnaire = st.session_state.questionnaires[current_questionnaire_id].duplicate()
|
| 53 |
+
|
| 54 |
+
#st.session_state.preview_questionnaire = questionnaire
|
| 55 |
+
|
| 56 |
+
col_options, col_prompt_display = st.columns(2)
|
| 57 |
+
|
| 58 |
+
with col_options:
|
| 59 |
+
st.subheader("⚙️ Configuration")
|
| 60 |
+
|
| 61 |
+
new_system_prompt = st.text_area(
|
| 62 |
+
label=system_prompt_field,
|
| 63 |
+
key=f"{system_prompt_field}{current_questionnaire_id}",
|
| 64 |
+
value=questionnaire.system_prompt,
|
| 65 |
+
help="The system prompt the model is prompted with."
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
change_all_system = state.create(
|
| 69 |
+
st.checkbox,
|
| 70 |
+
key=change_all_system_prompts_checkbox,
|
| 71 |
+
label="On update: change all System Prompts",
|
| 72 |
+
help="If this is ticked, all system prompts will be changed to this.",
|
| 73 |
+
initial_value=False
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
new_prompt = st.text_area(
|
| 77 |
+
label=prompt_field,
|
| 78 |
+
key=f"{prompt_field}{current_questionnaire_id}",
|
| 79 |
+
value=questionnaire.prompt,
|
| 80 |
+
help="Instructions that are given to the model before the questions."
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
change_all_questionnaire = state.create(
|
| 84 |
+
st.checkbox,
|
| 85 |
+
key=change_all_prompts_checkbox,
|
| 86 |
+
label="On update: change all questionnaire instructions",
|
| 87 |
+
help="If this is ticked, all questionnaire instructions will be changed to this.",
|
| 88 |
+
initial_value=False
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Place the corresponding output in the second column
|
| 92 |
+
with col_prompt_display:
|
| 93 |
+
st.subheader("📄 Live Preview")
|
| 94 |
+
|
| 95 |
+
# --- The Dynamic Preview Logic ---
|
| 96 |
+
# This block re-runs on every widget interaction.
|
| 97 |
+
with st.container(border=True):
|
| 98 |
+
questionnaire.system_prompt = new_system_prompt
|
| 99 |
+
questionnaire.prompt = new_prompt
|
| 100 |
+
current_system_prompt, current_prompt = questionnaire.get_prompt_for_questionnaire_type(QuestionnaireType.SEQUENTIAL)
|
| 101 |
+
current_system_prompt = current_system_prompt.replace("\n", " \n")
|
| 102 |
+
current_prompt = current_prompt.replace("\n", " \n")
|
| 103 |
+
st.write(current_system_prompt)
|
| 104 |
+
st.write(current_prompt)
|
| 105 |
+
if st.button("Update Prompt(s)", type="secondary", use_container_width=True):
|
| 106 |
+
if change_all_system:
|
| 107 |
+
for questionnaire in st.session_state.questionnaires:
|
| 108 |
+
questionnaire.system_prompt = new_system_prompt
|
| 109 |
+
else:
|
| 110 |
+
st.session_state.questionnaires[current_questionnaire_id].system_prompt = new_system_prompt
|
| 111 |
+
|
| 112 |
+
if change_all_questionnaire:
|
| 113 |
+
for questionnaire in st.session_state.questionnaires:
|
| 114 |
+
questionnaire.prompt = new_prompt
|
| 115 |
+
else:
|
| 116 |
+
st.session_state.questionnaires[current_questionnaire_id].prompt = new_prompt
|
| 117 |
+
st.success("Prompt(s) updated!")
|
| 118 |
+
|
| 119 |
+
if st.button("Confirm Base Prompt", type="primary", use_container_width=True):
|
| 120 |
+
st.switch_page("pages/02_Option_Prompt.py")
|
| 121 |
+
else:
|
| 122 |
+
st.warning("No data found. Please upload a CSV file on the 'Start Page' first.")
|
| 123 |
+
|
| 124 |
+
|
src/pages/old_pages/_03_Prepare_Prompts.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from qstn.survey_manager import SurveyOptionGenerator
|
| 3 |
+
from qstn.llm_questionnaire import LLMQuestionnaire
|
| 4 |
+
from qstn.utilities.constants import QuestionnaireType
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from gui_elements.paginator import paginator
|
| 8 |
+
|
| 9 |
+
# CONSTANTS FOR FIELDS
|
| 10 |
+
question_stem_field = "Question Stem"
|
| 11 |
+
randomize_order_tick = "Randomize the order of items"
|
| 12 |
+
|
| 13 |
+
field_ids = [question_stem_field, randomize_order_tick]
|
| 14 |
+
|
| 15 |
+
st.title("Questions Preparation")
|
| 16 |
+
st.write(
|
| 17 |
+
"This interface allows you configure how the questions are prompted to the LLM. These options are applied to every questionnaire in your survey."
|
| 18 |
+
)
|
| 19 |
+
st.page_link("pages/02_Option_Prompt.py", label="Click here to adjust the answer options.")
|
| 20 |
+
st.divider()
|
| 21 |
+
|
| 22 |
+
#current_questionnaire_id = paginator(st.session_state.questionnaires, "current_questionnaire_index_prepare")
|
| 23 |
+
|
| 24 |
+
if "questionnaires" not in st.session_state:
|
| 25 |
+
st.error("You need to first upload a questionnaire and the population you want to survey.")
|
| 26 |
+
st.stop()
|
| 27 |
+
disabled = True
|
| 28 |
+
else:
|
| 29 |
+
disabled = False
|
| 30 |
+
if not "temporary_questionnaire" in st.session_state:
|
| 31 |
+
st.session_state.temporary_questionnaire = st.session_state.questionnaires[0].duplicate()
|
| 32 |
+
|
| 33 |
+
#print(st.session_state.temporary_questionnaire._questions)
|
| 34 |
+
|
| 35 |
+
if not "base_questionnaire" in st.session_state:
|
| 36 |
+
st.session_state.base_questionnaire = st.session_state.temporary_questionnaire.duplicate()
|
| 37 |
+
|
| 38 |
+
def process_inputs(input: Any, field_id: str) -> str:
|
| 39 |
+
if "survey_options" in st.session_state:
|
| 40 |
+
survey_options = st.session_state.survey_options
|
| 41 |
+
else:
|
| 42 |
+
survey_options = None
|
| 43 |
+
|
| 44 |
+
if field_id == question_stem_field:
|
| 45 |
+
LLMQuestionnaire.prepare_questionnaire
|
| 46 |
+
st.session_state.temporary_questionnaire.prepare_questionnaire(
|
| 47 |
+
question_stem=input,
|
| 48 |
+
answer_options=survey_options,
|
| 49 |
+
randomized_item_order=randomize_order_bool,
|
| 50 |
+
)
|
| 51 |
+
st.session_state.base_questionnaire.prepare_questionnaire(
|
| 52 |
+
question_stem=input,
|
| 53 |
+
answer_options=survey_options,
|
| 54 |
+
randomized_item_order=randomize_order_bool,
|
| 55 |
+
)
|
| 56 |
+
# elif field_id == global_options_tick:
|
| 57 |
+
# option_behavior = input == "Give instruction in the beginning"
|
| 58 |
+
# st.session_state.temporary_questionnaire.prepare_questionnaire(
|
| 59 |
+
# question_stem=question_stem_input,
|
| 60 |
+
# answer_options=survey_options,
|
| 61 |
+
# global_options=option_behavior,
|
| 62 |
+
# randomized_item_order=randomize_order_bool,
|
| 63 |
+
# )
|
| 64 |
+
# st.session_state.base_questionnaire.prepare_questionnaire(
|
| 65 |
+
# question_stem=question_stem_input,
|
| 66 |
+
# answer_options=survey_options,
|
| 67 |
+
# global_options=option_behavior,
|
| 68 |
+
# randomized_item_order=False,
|
| 69 |
+
# )
|
| 70 |
+
elif field_id == randomize_order_tick:
|
| 71 |
+
if input == True:
|
| 72 |
+
st.session_state.temporary_questionnaire.prepare_questionnaire(
|
| 73 |
+
question_stem=question_stem_input,
|
| 74 |
+
answer_options=survey_options,
|
| 75 |
+
randomized_item_order=input,
|
| 76 |
+
)
|
| 77 |
+
else:
|
| 78 |
+
st.session_state.temporary_questionnaire = st.session_state.base_questionnaire.duplicate()
|
| 79 |
+
|
| 80 |
+
def handle_change(field_id: str):
|
| 81 |
+
"""
|
| 82 |
+
This single callback handles changes from any text field.
|
| 83 |
+
It reads the input from session state using the unique key,
|
| 84 |
+
processes it, and saves the output to session state.
|
| 85 |
+
"""
|
| 86 |
+
input_key = f"input_{field_id}"
|
| 87 |
+
|
| 88 |
+
with st.spinner(f"Processing {field_id}..."):
|
| 89 |
+
# time.sleep(0.5) # Simulate work
|
| 90 |
+
process_inputs(st.session_state[input_key], field_id)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
col1, col2 = st.columns(2, gap="large")
|
| 94 |
+
|
| 95 |
+
with col1:
|
| 96 |
+
st.subheader("⚙️ Configuration")
|
| 97 |
+
|
| 98 |
+
for field_id in field_ids:
|
| 99 |
+
input_key = f"input_{field_id}"
|
| 100 |
+
if not input_key in st.session_state:
|
| 101 |
+
if field_id == question_stem_field:
|
| 102 |
+
st.session_state[input_key] = st.session_state.temporary_questionnaire ._questions[0].question_stem
|
| 103 |
+
if field_id == randomize_order_tick:
|
| 104 |
+
st.session_state[input_key] = False
|
| 105 |
+
|
| 106 |
+
# --- Input Widgets (No Form) ---
|
| 107 |
+
question_stem_input = st.text_area(
|
| 108 |
+
question_stem_field,
|
| 109 |
+
key=f"input_{question_stem_field}",
|
| 110 |
+
# placeholder="e.g., How would you rate the following aspects of our service?",
|
| 111 |
+
#on_change=handle_change,
|
| 112 |
+
kwargs={'field_id': question_stem_field},
|
| 113 |
+
height=100,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# option_behavior = st.radio(
|
| 117 |
+
# global_options_tick,
|
| 118 |
+
# key=f"input_{global_options_tick}",
|
| 119 |
+
# options=[
|
| 120 |
+
# "Give instruction in the beginning",
|
| 121 |
+
# "Give options after each question",
|
| 122 |
+
# ],
|
| 123 |
+
# index=0,
|
| 124 |
+
# #on_change=handle_change,
|
| 125 |
+
# kwargs={'field_id': global_options_tick},
|
| 126 |
+
# help="Choose how answer options are applied.",
|
| 127 |
+
# )
|
| 128 |
+
|
| 129 |
+
randomize_order_bool = st.checkbox(
|
| 130 |
+
randomize_order_tick,
|
| 131 |
+
key=f"input_{randomize_order_tick}",
|
| 132 |
+
value=False,
|
| 133 |
+
#on_change=handle_change,
|
| 134 |
+
kwargs={'field_id': randomize_order_tick}
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
if "survey_options" in st.session_state:
|
| 138 |
+
survey_options = st.session_state.survey_options
|
| 139 |
+
else:
|
| 140 |
+
survey_options = None
|
| 141 |
+
|
| 142 |
+
if randomize_order_bool:
|
| 143 |
+
st.session_state.temporary_questionnaire.prepare_questionnaire(
|
| 144 |
+
question_stem=question_stem_input,
|
| 145 |
+
answer_options=survey_options,
|
| 146 |
+
randomized_item_order=randomize_order_bool,
|
| 147 |
+
)
|
| 148 |
+
st.session_state.base_questionnaire.prepare_questionnaire(
|
| 149 |
+
question_stem=question_stem_input,
|
| 150 |
+
answer_options=survey_options,
|
| 151 |
+
randomized_item_order=False,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
if not randomize_order_bool:
|
| 155 |
+
st.session_state.temporary_questionnaire = st.session_state.base_questionnaire.duplicate()
|
| 156 |
+
|
| 157 |
+
with col2:
|
| 158 |
+
st.subheader("📄 Live Preview")
|
| 159 |
+
|
| 160 |
+
#@Ahmed All of these could be a resusable function (it is used on almost all pages) Maybe split up the container in system prompt/ prompt
|
| 161 |
+
with st.container(border=True):
|
| 162 |
+
system_prompt, current_prompt = st.session_state.temporary_questionnaire.get_prompt_for_questionnaire_type(QuestionnaireType.BATTERY)
|
| 163 |
+
# markdown newlines
|
| 164 |
+
system_prompt = system_prompt.replace("\n", " \n")
|
| 165 |
+
current_prompt = current_prompt.replace("\n", " \n")
|
| 166 |
+
st.write(system_prompt)
|
| 167 |
+
st.write(current_prompt)
|
| 168 |
+
|
| 169 |
+
st.divider()
|
| 170 |
+
|
| 171 |
+
if st.button("Confirm and Prepare Questionnaire", type="primary", use_container_width=True):
|
| 172 |
+
for questionnaire in st.session_state.questionnaires:
|
| 173 |
+
questionnaire.prepare_questionnaire(
|
| 174 |
+
question_stem=question_stem_input,
|
| 175 |
+
answer_options=survey_options,
|
| 176 |
+
randomized_item_order=randomize_order_bool,
|
| 177 |
+
)
|
| 178 |
+
st.success("Changed the prompts!")
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/surveyGenGui.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: surveyGenGui
|
| 3 |
+
Version: 0.0.1
|
src/surveyGenGui.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pyproject.toml
|
| 2 |
+
src/Start_Page.py
|
| 3 |
+
src/__init__.py
|
| 4 |
+
src/gui_elements/__init__.py
|
| 5 |
+
src/gui_elements/output_manager.py
|
| 6 |
+
src/gui_elements/paginator.py
|
| 7 |
+
src/gui_elements/stateful_widget.py
|
| 8 |
+
src/pages/01_Basic_Prompt_Settings.py
|
| 9 |
+
src/pages/02_Option_Prompt.py
|
| 10 |
+
src/pages/03_Prepare_Prompts.py
|
| 11 |
+
src/pages/04_Inference_Setting.py
|
| 12 |
+
src/pages/05_Final_Overview.py
|
| 13 |
+
src/pages/__init__.py
|
| 14 |
+
src/surveyGenGui.egg-info/PKG-INFO
|
| 15 |
+
src/surveyGenGui.egg-info/SOURCES.txt
|
| 16 |
+
src/surveyGenGui.egg-info/dependency_links.txt
|
| 17 |
+
src/surveyGenGui.egg-info/top_level.txt
|
src/surveyGenGui.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/surveyGenGui.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Start_Page
|
| 2 |
+
__init__
|
| 3 |
+
gui_elements
|
| 4 |
+
pages
|