File size: 11,974 Bytes
448ed98
 
 
 
 
 
 
 
 
 
5bcad58
448ed98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bcad58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448ed98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bcad58
448ed98
 
 
 
 
5bcad58
 
 
448ed98
5bcad58
448ed98
5bcad58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import streamlit as st
from gui_elements.paginator import paginator
from gui_elements.stateful_widget import StatefulWidgets
from gui_elements.output_manager import st_capture, TqdmToStreamlit

import io
import queue
import time
import threading
import asyncio
import pandas as pd

import logging

from contextlib import redirect_stderr, redirect_stdout

from qstn.parser.llm_answer_parser import raw_responses
from qstn.utilities.constants import QuestionnairePresentation
from qstn.utilities.utils import create_one_dataframe
from qstn.survey_manager import (
    conduct_survey_sequential,
    conduct_survey_battery,
    conduct_survey_single_item,
)

from streamlit.runtime.scriptrunner import add_script_run_ctx

from openai import AsyncOpenAI

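# The OpenAI-compatible client (e.g. one pointing at a vLLM API server) is built
# from st.session_state.client_config when the run button further down is pressed.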


# Guard: this page cannot do anything without questionnaires in session state.
# st.stop() ends the current script run, so no code is needed after it inside
# the branch (the former `disabled = True` there could never take effect).
if "questionnaires" not in st.session_state:
    st.error(
        "You need to first upload a questionnaire and the population you want to survey."
    )
    st.stop()

# Reaching this line means the prerequisite data exists. The name is kept at
# module level so anything that reads `disabled` keeps working.
disabled = False


@st.cache_data
def create_stateful_widget() -> StatefulWidgets:
    """Return the ``StatefulWidgets`` helper used to build this page's inputs.

    The function takes no arguments and is wrapped in ``st.cache_data``.

    NOTE(review): per the Streamlit docs ``st.cache_data`` hands out a copy of
    the cached value on each call; if ``StatefulWidgets`` is meant to be one
    shared object, ``st.cache_resource`` may be the better fit -- confirm.
    """
    widgets = StatefulWidgets()
    return widgets


# Helper through which every stateful widget on this page is created.
state = create_stateful_widget()

# The paginator's return value is used as an index into the questionnaire list.
all_questionnaires = st.session_state.questionnaires
current_index = paginator(all_questionnaires, "overview_page")

# Despite the plural name, this is the single entry at the current page index.
questionnaires = all_questionnaires[current_index]

# Two-column layout: inference parameters on the left, prompt preview on the right.
col_llm, col_prompt_display = st.columns(2)

# Left column: overview of the inference parameters.
# Every widget here is created with disabled=True, so this column only displays
# values. NOTE(review): the run handler further down reads
# st.session_state.inference_config rather than the values returned here.
# state.create presumably takes (widget function, state key, label, **widget
# kwargs) -- confirm against StatefulWidgets.create.
with col_llm:
    st.subheader("⚙️ Inference Parameters")

    with st.container(border=True):
        st.subheader("Core Settings")
        # Model identifier; no initial value is supplied here.
        model_name = state.create(
            st.text_input,
            "model_name",
            "Model Name",
            # initial_value="meta-llama/Llama-3.1-70B-Instruct",
            # placeholder="meta-llama/Llama-3.1-70B-Instruct",
            disabled=True,
            help="The model to use for the inference call.",
        )

        # Sampling temperature, slider range 0.0-2.0 in steps of 0.01.
        temperature = state.create(
            st.slider,
            "temperature",
            "Temperature",
            min_value=0.0,
            max_value=2.0,
            step=0.01,
            initial_value=1.0,
            disabled=True,
            help="Controls randomness. Lower values are more deterministic and less creative.",
        )

        # Upper bound on generated tokens (at least 1).
        max_tokens = state.create(
            st.number_input,
            "max_tokens",
            "Max Tokens",
            initial_value=1024,
            min_value=1,
            disabled=True,
            help="The maximum number of tokens to generate in the completion.",
        )

        # Nucleus-sampling threshold, slider range 0.0-1.0 in steps of 0.01.
        top_p = state.create(
            st.slider,
            "top_p",
            "Top P",
            min_value=0.0,
            max_value=1.0,
            step=0.01,
            initial_value=1.0,
            disabled=True,
            help="Controls nucleus sampling. The model considers tokens with top_p probability mass.",
        )

        # Non-negative seed value.
        seed = state.create(
            st.number_input,
            "seed",
            "Seed",
            initial_value=42,
            min_value=0,
            disabled=True,
            help="A specific seed for reproducibility of results.",
        )

        # Free-form JSON text for any additional inference parameters.
        with st.expander("Advanced Inference Settings (JSON)"):
            advanced_inference_params_str = state.create(
                st.text_area,
                "advanced_inference_params_str",
                "JSON for other inference parameters",
                initial_value="",
                # placeholder='{\n  "stop": ["\\n", " Human:"],\n  "presence_penalty": 0\n}',
                height=150,
                disabled=True,
                help='Enter any other valid inference parameters like "stop", "logit_bias", or "frequency_penalty" as a JSON object.',
            )


def _show_prompt_pair(system_text, user_text):
    """Write a system prompt and a user prompt to the page, in that order.

    Each newline gets two spaces in front of it so that Markdown renders it as
    a hard line break.
    """
    system_markdown = system_text.replace("\n", "  \n")
    user_markdown = user_text.replace("\n", "  \n")
    st.write(system_markdown)
    st.write(user_markdown)


# Right column: choose the presentation method and preview the resulting prompts.
with col_prompt_display:
    st.subheader("📄 Live Preview")

    # Display label -> (internal method name, presentation enum member).
    survey_method_options = {
        "Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM),
        "Battery": ("battery", QuestionnairePresentation.BATTERY),
        "Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL),
    }

    survey_method_display = state.create(
        st.selectbox,
        "survey_method",
        "Questionnaire Method",
        options=list(survey_method_options.keys()),
        initial_value="Single item",
        help="Choose how to conduct the questionnaire: Single item (one at a time), Battery (all questions together), or Sequential (with conversation history).",
    )

    # Resolve the chosen label into its method name and presentation type.
    selected_method_name, selected_questionnaire_type = survey_method_options[
        survey_method_display
    ]

    with st.container(border=True):
        if selected_questionnaire_type != QuestionnairePresentation.SINGLE_ITEM:
            # Battery and sequential modes: one combined preview.
            system_text, user_text = questionnaires.get_prompt_for_questionnaire_type(
                selected_questionnaire_type
            )
            _show_prompt_pair(system_text, user_text)
        else:
            # Single-item mode: preview at most the first three items.
            num_previews = min(3, len(questionnaires._questions))
            several_items = num_previews > 1

            if several_items:
                st.write(f"**Preview of first {num_previews} items:**")
            else:
                st.write("**Preview:**")

            for i in range(num_previews):
                if several_items:
                    st.write(f"**Item {i+1}:**")

                system_text, user_text = questionnaires.get_prompt_for_questionnaire_type(
                    selected_questionnaire_type,
                    item_id=i,
                )
                _show_prompt_pair(system_text, user_text)

                # Divider between items, but not after the final one.
                if i != num_previews - 1:
                    st.divider()


# Run the questionnaire in a background thread and mirror its progress output
# into the page while it works.
if st.button("Confirm and Run Questionnaire", type="primary", use_container_width=True):
    st.write("Starting inference...")

    # Client settings and inference settings are read from session state.
    client = AsyncOpenAI(**st.session_state.client_config)

    # Work on a copy so that popping "model" leaves the stored config intact.
    inference_config = st.session_state.inference_config.copy()
    model_name = inference_config.pop("model")

    # log_queue carries progress text from the worker; result_queue carries its
    # single final value (the survey result or the exception it raised).
    log_queue = queue.Queue()
    result_queue = queue.Queue()

    class QueueWriter:
        """Minimal file-like object that forwards written text to a queue."""

        def __init__(self, q):
            self.q = q

        def write(self, message):
            # Skip whitespace-only writes.
            if message.strip():
                self.q.put(message)

        def flush(self):
            # Required by the file-like interface; there is nothing to flush.
            pass

    def run_async_in_thread(
        result_q, client, questionnaires, model_name, survey_method_name, **inference_config
    ):
        """Run the selected survey function and put its outcome on ``result_q``.

        Args:
            result_q: Queue that receives exactly one item: the survey result,
                or the exception instance if the survey raised.
            client: Client object handed to the survey function.
            questionnaires: Passed to the survey function as ``llm_prompts``.
            model_name: Passed to the survey function as ``client_model_name``.
            survey_method_name: ``"single_item"``, ``"battery"`` or
                ``"sequential"``; any other value falls back to single item.
            **inference_config: Forwarded unchanged to the survey function.

        No Streamlit call is made in here: this runs in a worker thread that
        has no script-run context, so failures are reported by the caller.
        """
        survey_funcs = {
            "single_item": conduct_survey_single_item,
            "battery": conduct_survey_battery,
            "sequential": conduct_survey_sequential,
        }
        survey_func = survey_funcs.get(survey_method_name, conduct_survey_single_item)

        try:
            # stderr is captured into log_queue so the main thread can show it.
            # NOTE: redirect_stderr replaces sys.stderr for the whole process,
            # not only for this thread.
            # TODO: API concurrency should be configurable in the GUI.
            with redirect_stderr(QueueWriter(log_queue)):
                result = survey_func(
                    client,
                    llm_prompts=questionnaires,
                    client_model_name=model_name,
                    api_concurrency=100,
                    **inference_config,
                )
        except Exception as exc:
            # Hand the failure to the main thread instead of losing it.
            result_q.put(exc)
        else:
            result_q.put(result)

    # Translate the stored display label into the internal method name.
    # NOTE(review): assumes the selectbox value is stored in session state under
    # the key "survey_method" -- confirm against StatefulWidgets.create.
    method_names = {
        "Single item": "single_item",
        "Battery": "battery",
        "Sequential": "sequential",
    }
    survey_method_display = st.session_state.get("survey_method", "Single item")
    selected_method_name = method_names.get(survey_method_display, "single_item")

    thread = threading.Thread(
        target=run_async_in_thread,
        args=(
            result_queue,
            client,
            st.session_state.questionnaires,
            model_name,
            selected_method_name,
        ),
        kwargs=inference_config,
    )
    thread.start()

    all_questions_placeholder = st.empty()
    progress_placeholder = st.empty()

    def show_pending_logs():
        """Display every log message currently waiting in ``log_queue``."""
        while True:
            try:
                log_message = log_queue.get_nowait()
            except queue.Empty:
                return
            # tqdm animates with carriage returns in a console; show clean lines.
            text = log_message.strip().replace("\r", "")
            # HACK: routing by substring; QSTN should emit clearly parsable messages.
            if "Processing Prompts" in log_message:
                progress_placeholder.text(text)
            elif "[A" not in log_message:
                all_questions_placeholder.text(text)

    # The UI is only written from here, the main script thread.
    while thread.is_alive():
        show_pending_logs()
        time.sleep(0.1)
    thread.join()

    all_questions_placeholder.empty()
    progress_placeholder.empty()

    try:
        final_output = result_queue.get_nowait()
    except queue.Empty:
        st.error("Could not retrieve result from the asynchronous task.")
    else:
        if isinstance(final_output, Exception):
            # The worker failed; report it rather than parsing an exception.
            st.error(f"Inference failed: {type(final_output).__name__}: {final_output}")
        else:
            st.success("Finished inferencing!")

            responses = raw_responses(final_output)
            df = create_one_dataframe(responses)

            # Store the dataframe in session state for saving later.
            st.session_state.results_dataframe = df
            st.session_state.inference_completed = True

            st.dataframe(df)

# Show the save section once an inference run has stored its results.
if st.session_state.get("inference_completed") and "results_dataframe" in st.session_state:
    st.divider()
    st.subheader("💾 Save Results")

    default_filename = "questionnaire_results.csv"

    # Seed the filename shown in the text box on first display.
    if "save_filename" not in st.session_state:
        st.session_state.save_filename = default_filename

    save_filename = st.text_input(
        "Save File",
        value=st.session_state.save_filename,
        key="save_filename_input",
        help="Enter the filename for the results. Should be a CSV file (e.g., results.csv).",
    )

    # Normalise the name: ignore surrounding whitespace, fall back to the
    # default when nothing usable was entered, and append ".csv" unless the
    # name already ends with it in any letter case.
    save_filename = (save_filename or "").strip()
    if not save_filename:
        save_filename = default_filename
    elif not save_filename.lower().endswith(".csv"):
        save_filename += ".csv"

    # Serialise the stored results to CSV text for the download button.
    csv_data = st.session_state.results_dataframe.to_csv(index=False)

    st.download_button(
        label="Save Results",
        data=csv_data,
        file_name=save_filename,
        mime="text/csv",
        type="primary",
        use_container_width=True,
        help="Click to save the results to your computer. You can choose the directory and filename in the save dialog.",
    )