Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from gui_elements.paginator import paginator | |
| from gui_elements.stateful_widget import StatefulWidgets | |
| from gui_elements.output_manager import st_capture, TqdmToStreamlit | |
| import io | |
| import queue | |
| import time | |
| import threading | |
| import asyncio | |
| import pandas as pd | |
| import logging | |
| from contextlib import redirect_stderr, redirect_stdout | |
| from qstn.parser.llm_answer_parser import raw_responses, parse_json, parse_json_battery | |
| from qstn.utilities.constants import QuestionnairePresentation | |
| from qstn.utilities.utils import create_one_dataframe | |
| from qstn.inference.response_generation import JSONResponseGenerationMethod | |
| from qstn.survey_manager import ( | |
| conduct_survey_sequential, | |
| conduct_survey_battery, | |
| conduct_survey_single_item, | |
| ) | |
| from streamlit.runtime.scriptrunner import add_script_run_ctx | |
| from openai import AsyncOpenAI | |
| # Set OpenAI's API key and API base to use vLLM's API server. | |
| if "questionnaires" not in st.session_state: | |
| st.error( | |
| "You need to first upload a questionnaire and the population you want to survey." | |
| ) | |
| st.stop() | |
| disabled = True | |
| else: | |
| disabled = False | |
| def create_stateful_widget() -> StatefulWidgets: | |
| return StatefulWidgets() | |
| state = create_stateful_widget() | |
| current_index = paginator(st.session_state.questionnaires, "overview_page") | |
| questionnaires = st.session_state.questionnaires[current_index] | |
| col_llm, col_prompt_display = st.columns(2) | |
| with col_llm: | |
| st.subheader("⚙️ Inference Parameters") | |
| with st.container(border=True): | |
| st.subheader("Core Settings") | |
| model_name = state.create( | |
| st.text_input, | |
| "model_name", | |
| "Model Name", | |
| disabled=True, | |
| help="The model to use for the inference call.", | |
| ) | |
| temperature = state.create( | |
| st.slider, | |
| "temperature", | |
| "Temperature", | |
| min_value=0.0, | |
| max_value=2.0, | |
| step=0.01, | |
| initial_value=1.0, | |
| disabled=True, | |
| help="Controls randomness. Lower values are more deterministic and less creative.", | |
| ) | |
| max_tokens = state.create( | |
| st.number_input, | |
| "max_tokens", | |
| "Max Tokens", | |
| initial_value=1024, | |
| min_value=1, | |
| disabled=True, | |
| help="The maximum number of tokens to generate in the completion.", | |
| ) | |
| top_p = state.create( | |
| st.slider, | |
| "top_p", | |
| "Top P", | |
| min_value=0.0, | |
| max_value=1.0, | |
| step=0.01, | |
| initial_value=1.0, | |
| disabled=True, | |
| help="Controls nucleus sampling. The model considers tokens with top_p probability mass.", | |
| ) | |
| seed = state.create( | |
| st.number_input, | |
| "seed", | |
| "Seed", | |
| initial_value=42, | |
| min_value=0, | |
| disabled=True, | |
| help="A specific seed for reproducibility of results.", | |
| ) | |
| with st.expander("Advanced Inference Settings (JSON)"): | |
| advanced_inference_params_str = state.create( | |
| st.text_area, | |
| "advanced_inference_params_str", | |
| "JSON for other inference parameters", | |
| initial_value="", | |
| height=150, | |
| disabled=True, | |
| help='Enter any other valid inference parameters like "stop", "logit_bias", or "frequency_penalty" as a JSON object.', | |
| ) | |
| with col_prompt_display: | |
| st.subheader("📄 Live Preview") | |
| # Survey method selector | |
| survey_method_options = { | |
| "Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM), | |
| "Battery": ("battery", QuestionnairePresentation.BATTERY), | |
| "Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL), | |
| } | |
| survey_method_display = state.create( | |
| st.selectbox, | |
| "survey_method", | |
| "Questionnaire Method", | |
| options=list(survey_method_options.keys()), | |
| initial_value="Single item", | |
| help="Choose how to conduct the questionnaire: Single item (one at a time), Battery (all questions together), or Sequential (with conversation history)." | |
| ) | |
| # Get the method name and questionnaire type from selection | |
| selected_method_name, selected_questionnaire_type = survey_method_options[survey_method_display] | |
| with st.container(border=True): | |
| # Add a visible questionnaire identifier to force Streamlit to update when switching | |
| # This ensures the preview updates even when "change all" makes content identical | |
| st.caption(f"📋 Questionnaire {current_index + 1} of {len(st.session_state.questionnaires)}") | |
| # For single item mode, show multiple previews (up to 3 items) | |
| if selected_questionnaire_type == QuestionnairePresentation.SINGLE_ITEM: | |
| num_questions = len(questionnaires._questions) | |
| num_previews = min(3, num_questions) # Show up to 3 previews | |
| if num_previews > 1: | |
| st.write(f"**Preview of first {num_previews} items:**") | |
| else: | |
| st.write("**Preview:**") | |
| for i in range(num_previews): | |
| if num_previews > 1: | |
| st.write(f"**Item {i+1}:**") | |
| current_system_prompt, current_prompt = questionnaires.get_prompt_for_questionnaire_type( | |
| selected_questionnaire_type, | |
| item_id=i | |
| ) | |
| current_system_prompt = current_system_prompt.replace("\n", " \n") | |
| current_prompt = current_prompt.replace("\n", " \n") | |
| # Use a unique container key for each questionnaire/item to force Streamlit to update when switching | |
| # This ensures the preview updates even when content is identical | |
| with st.container(key=f"preview_container_{current_index}_{i}"): | |
| st.markdown(current_system_prompt) | |
| st.write(current_prompt) | |
| # Add separator between items (except for the last one) | |
| if i < num_previews - 1: | |
| st.divider() | |
| else: | |
| # For battery and sequential, show single preview as before | |
| current_system_prompt, current_prompt = questionnaires.get_prompt_for_questionnaire_type(selected_questionnaire_type) | |
| current_system_prompt = current_system_prompt.replace("\n", " \n") | |
| current_prompt = current_prompt.replace("\n", " \n") | |
| # Use a unique container key for each questionnaire to force Streamlit to update when switching | |
| # This ensures the preview updates even when content is identical | |
| with st.container(key=f"preview_container_{current_index}"): | |
| st.markdown(current_system_prompt) | |
| st.write(current_prompt) | |
| if st.button("Confirm and Run Questionnaire", type="primary", use_container_width=True): | |
| st.write("Starting inference...") | |
| openai_api_key = "EMPTY" | |
| openai_api_base = "http://localhost:8000/v1" | |
| client = AsyncOpenAI(**st.session_state.client_config) | |
| inference_config = st.session_state.inference_config.copy() | |
| model_name = inference_config.pop("model") | |
| progress_text = st.empty() | |
| log_queue = queue.Queue() | |
| result_queue = queue.Queue() | |
| class QueueWriter: | |
| def __init__(self, q): | |
| self.q = q | |
| def write(self, message): | |
| if message.strip(): | |
| self.q.put(message) | |
| def flush(self): | |
| # This function is needed to match the file-like object interface | |
| # but we don't need to do anything here. | |
| pass | |
| # Helper function for asyncronous runs | |
| def run_async_in_thread( | |
| result_q, client, questionnaires, model_name, survey_method_name, **inference_config | |
| ): | |
| queue_writer = QueueWriter(log_queue) | |
| # We need to redirect the output to a queue, as streamlit does not support multithreading | |
| # API concurrency should be configurable in the GUI | |
| try: | |
| with redirect_stderr(queue_writer): | |
| # Select the appropriate survey method based on user choice | |
| if survey_method_name == "single_item": | |
| survey_func = conduct_survey_single_item | |
| elif survey_method_name == "battery": | |
| survey_func = conduct_survey_battery | |
| elif survey_method_name == "sequential": | |
| survey_func = conduct_survey_sequential | |
| else: | |
| survey_func = conduct_survey_single_item # Default fallback | |
| result = survey_func( | |
| client, | |
| llm_prompts=questionnaires, | |
| client_model_name=model_name, | |
| api_concurrency=100, | |
| **inference_config, | |
| ) | |
| except Exception as e: | |
| result = e | |
| st.error(e) | |
| finally: | |
| result_q.put(result) | |
| while not log_queue.empty(): | |
| log_queue.get() | |
| while not result_queue.empty(): | |
| result_queue.get() | |
| # Get the selected survey method | |
| survey_method_display = st.session_state.get("survey_method", "Single item") | |
| survey_method_options = { | |
| "Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM), | |
| "Battery": ("battery", QuestionnairePresentation.BATTERY), | |
| "Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL), | |
| } | |
| selected_method_name, _ = survey_method_options.get(survey_method_display, ("single_item", QuestionnairePresentation.SINGLE_ITEM)) | |
| thread = threading.Thread( | |
| target=run_async_in_thread, | |
| args=(result_queue, client, st.session_state.questionnaires, model_name, selected_method_name), | |
| kwargs=inference_config, | |
| ) | |
| thread.start() | |
| all_questions_placeholder = st.empty() | |
| progress_placeholder = st.empty() | |
| while thread.is_alive(): | |
| try: | |
| # Here we can write directly to the UI, as it is the main thread | |
| # TQDM uses carriage returns (\r) to animate in the console, we only show clear lines | |
| log_message = log_queue.get_nowait() | |
| # This is quite a hacky solution for now, we should adjust QSTN to make the messages clearly parsable. | |
| if "[A" not in log_message and "Processing Prompts" not in log_message: | |
| all_questions_placeholder.text(log_message.strip().replace("\r", "")) | |
| elif "Processing Prompts" in log_message: | |
| progress_placeholder.text(log_message.strip().replace("\r", "")) | |
| except queue.Empty: | |
| pass | |
| time.sleep(0.1) | |
| thread.join() | |
| all_questions_placeholder.empty() | |
| progress_placeholder.empty() | |
| try: | |
| final_output = result_queue.get_nowait() | |
| except queue.Empty: | |
| st.error("Could not retrieve result from the asynchronous task.") | |
| st.success("Finished inferencing!") | |
| # Check if any questionnaire uses JSON response generation methods | |
| has_json_rgm = False | |
| if isinstance(final_output, list) and len(final_output) > 0: | |
| for result in final_output: | |
| if hasattr(result, 'questionnaire') and hasattr(result.questionnaire, '_questions'): | |
| for question in result.questionnaire._questions: | |
| if (hasattr(question, 'answer_options') and | |
| question.answer_options and | |
| hasattr(question.answer_options, 'response_generation_method') and | |
| question.answer_options.response_generation_method): | |
| rgm = question.answer_options.response_generation_method | |
| if isinstance(rgm, JSONResponseGenerationMethod): | |
| has_json_rgm = True | |
| break | |
| if has_json_rgm: | |
| break | |
| # Use appropriate parser based on response generation method | |
| if has_json_rgm: | |
| # Check survey method to use correct parser | |
| survey_method_display = st.session_state.get("survey_method", "Single item") | |
| if survey_method_display == "Battery": | |
| responses = parse_json_battery(final_output) | |
| else: | |
| responses = parse_json(final_output) | |
| else: | |
| responses = raw_responses(final_output) | |
| df = create_one_dataframe(responses) | |
| # Store the dataframe in session state for saving later | |
| st.session_state.results_dataframe = df | |
| st.session_state.inference_completed = True | |
| st.dataframe(df) | |
| # Show save button if inference is completed | |
| if "inference_completed" in st.session_state and st.session_state.inference_completed: | |
| st.divider() | |
| st.subheader("💾 Save Results") | |
| # Text input for filename | |
| if "save_filename" not in st.session_state: | |
| st.session_state.save_filename = "questionnaire_results.csv" | |
| save_filename = st.text_input( | |
| "Save File", | |
| value=st.session_state.save_filename, | |
| key="save_filename_input", | |
| help="Enter the filename for the results. Should be a CSV file (e.g., results.csv)." | |
| ) | |
| # Ensure filename ends with .csv | |
| if save_filename and not save_filename.endswith('.csv'): | |
| save_filename = save_filename + '.csv' | |
| # Convert dataframe to CSV string for download | |
| csv = st.session_state.results_dataframe.to_csv(index=False) | |
| st.download_button( | |
| label="Save Results", | |
| data=csv, | |
| file_name=save_filename if save_filename else "questionnaire_results.csv", | |
| mime="text/csv", | |
| type="primary", | |
| use_container_width=True, | |
| help="Click to save the results to your computer. You can choose the directory and filename in the save dialog." | |
| ) | |