Spaces:
Sleeping
Sleeping
| from db.schema import Response, ModelRatings | |
| import streamlit as st | |
| from datetime import datetime | |
| from dotenv import load_dotenv | |
| from views.nav_buttons import navigation_buttons | |
| import random | |
| from utils.loaders import load_html | |
# Load variables from a local .env file into the process environment at
# import time, before any function below runs.
load_dotenv()
def display_completion_message():
    """Show the "survey already completed" banner and close out the session.

    Renders an HTML notice and then sets the session flags
    ``show_questions`` (False), ``completed`` (True) and
    ``start_new_survey`` (True).
    """
    completion_html = """
    <div class='exit-container'>
    <h1>You have already completed the survey! Thank you for participating!</h1>
    <p>Your responses have been saved successfully.</p>
    <p>You can safely close this window or start a new survey.</p>
    </div>
    """
    st.markdown(completion_html, unsafe_allow_html=True)

    session_flags = (
        ("show_questions", False),
        ("completed", True),
        ("start_new_survey", True),
    )
    for flag_name, flag_value in session_flags:
        st.session_state[flag_name] = flag_value
def _read_field(obj, name, default=None):
    """Return ``obj.name``, else ``obj[name]``; ``default`` if a mapping lacks it.

    Saved responses/ratings may be schema objects (attribute access) or plain
    dicts (item access); both are supported here.
    """
    try:
        return getattr(obj, name)
    except AttributeError:
        pass
    try:
        return obj[name]
    except KeyError:
        return default


def get_previous_ratings(model_name, query_key, current_index):
    """Retrieve previously saved ratings for one model/query from session state.

    When ``current_index`` is behind ``st.session_state.current_index`` (and a
    response exists for it), ratings are looked up in
    ``st.session_state.previous_ratings`` by the row's ``config_id``; otherwise
    they come from ``st.session_state.responses[current_index]`` if present.

    Args:
        model_name: Key inside ``model_ratings`` (e.g. "gemini", "llama").
        query_key: Query variant; matched by containing "query_v",
            "query_p0" or "query_p1" (checked in that order).
        current_index: Index of the survey item.

    Returns:
        The stored ratings for that query, or ``{}`` when nothing is saved
        (including when a stored record lacks the expected keys).
    """
    session = st.session_state
    has_response = len(session.responses) > current_index

    previous_ratings = None
    if current_index < session.current_index and has_response:
        if session.previous_ratings:
            config_id = session.data.iloc[current_index]["config_id"]
            previous_ratings = session.previous_ratings.get(config_id, {}).get(
                model_name, None
            )
    elif has_response:
        response = session.responses[current_index]
        # A record without ``model_ratings`` means "nothing saved", not a crash.
        model_ratings = _read_field(response, "model_ratings") or {}
        previous_ratings = model_ratings.get(model_name, {})

    if not previous_ratings:
        return {}

    for marker in ("query_v", "query_p0", "query_p1"):
        if marker in query_key:
            return _read_field(previous_ratings, f"{marker}_ratings") or {}
    return {}
def render_single_rating(
    label,
    options,
    format_func,
    key_prefix,
    stored_rating,
    col,
):
    """Render one ``st.radio`` rating widget inside ``col`` and return its value.

    Args:
        label: Text shown above the radio.
        options: Selectable option values.
        format_func: Maps an option value to its display label.
        key_prefix: Streamlit widget key (used verbatim, as a string).
        stored_rating: Previously saved option *value* to preselect, or None.
        col: Streamlit container (e.g. a column) to render into.

    Returns:
        Whatever ``st.radio`` returns: the selected option value, or None when
        nothing is selected.
    """
    options = list(options)
    # st.radio's ``index`` is a position, while callers pass the saved value.
    # Translate value -> position so any option list works; an unknown or
    # missing value leaves the radio unselected instead of raising.
    index = options.index(stored_rating) if stored_rating in options else None
    with col:
        return st.radio(
            label,
            options=options,
            format_func=format_func,
            key=str(key_prefix),
            index=index,
        )
def clean_query_text(query_text):
    """Normalise a generated query string for display.

    Surrounding single/double quotes are removed, a full stop is appended
    unless the text already ends in ".", "?", "!" or a newline, and the first
    character is upper-cased. Quotes/apostrophes inside the text and the case
    of the remaining characters (e.g. city names) are preserved.

    >>> clean_query_text('"cafes near the Louvre"')
    'Cafes near the Louvre.'
    """
    quotes = ('"', "'")
    if query_text.startswith(quotes) or query_text.endswith(quotes):
        # Strip only the enclosing quotes; inner apostrophes ("what's") stay.
        query_text = query_text.strip("\"'")
    if not query_text:
        # Nothing left to punctuate or capitalise (e.g. the input was '""').
        return query_text
    if not query_text.endswith((".", "?", "!", "\n")):
        query_text += "."
    # Upper-case only the first character: str.capitalize() would lower-case
    # the rest and mangle proper nouns and acronyms.
    return query_text[0].upper() + query_text[1:]
def render_query_ratings(
    model_name,
    config,
    query_key,
    current_index,
    has_persona_alignment=False,
):
    """Show one generated query as a coloured card followed by its rating radios.

    Args:
        model_name: Model prefix of the config column, e.g. "gemini" or "llama".
        config: One survey row; ``config[model_name + "_" + query_key]`` holds
            the query text.
        query_key: Query variant key ("query_v", "query_p0" or "query_p1").
        current_index: Survey item index, used to preselect saved ratings.
        has_persona_alignment: Also render the "Persona Alignment" radio.

    Returns:
        Dict with keys "clarity", "groundedness", "persona_alignment" and
        "overall" holding the selected values; "persona_alignment" is None
        when ``has_persona_alignment`` is False.
    """
    saved = get_previous_ratings(model_name, query_key, current_index)
    column_name = model_name + "_" + query_key
    card_colour = "#e0f7fa" if model_name == "gemini" else "#f0f4c3"
    display_text = clean_query_text(config[column_name])

    five_point = [0, 1, 2, 3, 4]
    four_point = [0, 1, 2, 3]
    # (field, label, options, option labels); ``field`` also forms the widget
    # key and the key of the returned dict.
    rating_specs = [
        (
            "groundedness",
            "Groundedness:",
            five_point,
            ("N/A", "Not Grounded", "Partially Grounded", "Grounded", "Unclear"),
        ),
    ]
    if has_persona_alignment:
        rating_specs.append(
            (
                "persona_alignment",
                "Persona Alignment:",
                five_point,
                ("N/A", "Not Aligned", "Partially Aligned", "Aligned", "Unclear"),
            )
        )
    rating_specs.extend(
        [
            (
                "clarity",
                "Clarity:",
                four_point,
                ("N/A", "Not Clear", "Somewhat Clear", "Very Clear"),
            ),
            (
                "overall",
                "Overall Fit:",
                four_point,
                ("N/A", "Poor", "Moderate", "Strong Fit"),
            ),
        ]
    )

    selected = {"persona_alignment": None}
    with st.container():
        # NOTE(review): the heading is the column's position in ``config``
        # minus 5 — presumably the number of leading metadata columns; confirm.
        heading = config.index.get_loc(column_name) - 5
        st.markdown(
            f"""
            <div style="background-color:{card_colour}; padding:1rem;">
            <h3 style="text-align:left;color:black;">
            {heading}
            </h3>
            <p style="text-align:left;color:black;">
            {display_text}</p>
            </div>
            """,
            unsafe_allow_html=True,
        )
        # One column per radio, filled left to right.
        columns = st.columns(len(rating_specs))
        for column, (field, label, options, option_labels) in zip(
            columns, rating_specs
        ):
            selected[field] = render_single_rating(
                label,
                options,
                option_labels.__getitem__,
                f"rating_{model_name}{query_key}_{field}_",
                saved.get(field, 0),
                column,
            )

    return {
        "clarity": selected["clarity"],
        "groundedness": selected["groundedness"],
        "persona_alignment": selected["persona_alignment"],
        "overall": selected["overall"],
    }
def display_ratings_row(model_name, config, current_index):
    """Render one model's three query variants side by side and collect ratings.

    The plain query ("query_v") is rated without persona alignment; the two
    persona queries ("query_p0", "query_p1") include it.

    Returns:
        Dict with "query_v_ratings", "query_p0_ratings" and "query_p1_ratings";
        the "query_v_ratings" entry carries no "persona_alignment" key.
    """
    layout = (
        ("query_v", False),
        ("query_p0", True),
        ("query_p1", True),
    )
    collected = {}
    for column, (query_key, with_persona) in zip(st.columns(3), layout):
        with column:
            collected[f"{query_key}_ratings"] = render_query_ratings(
                model_name,
                config,
                query_key,
                current_index,
                has_persona_alignment=with_persona,
            )
    # The plain query never asks about persona fit, so drop the placeholder.
    collected["query_v_ratings"].pop("persona_alignment", None)
    return collected
def _to_model_ratings(row_ratings):
    """Wrap one model's ``display_ratings_row`` result in a ``ModelRatings``."""
    return ModelRatings(
        query_v_ratings=row_ratings["query_v_ratings"],
        query_p0_ratings=row_ratings["query_p0_ratings"],
        query_p1_ratings=row_ratings["query_p1_ratings"],
    )


def _save_response(current_index, response):
    """Record ``response`` and its model ratings for ``current_index`` in session state."""
    # ``Response`` may be dict-like or attribute-only; support both access styles.
    try:
        st.session_state.ratings[current_index] = response["model_ratings"]
    except TypeError:
        st.session_state.ratings[current_index] = response.model_ratings
    # Overwrite the stored answer when revisiting an item, otherwise append it.
    if len(st.session_state.responses) > current_index:
        st.session_state.responses[current_index] = response
    else:
        st.session_state.responses.append(response)


def questions_screen(data):
    """Display the questions screen for ``st.session_state.current_index``.

    Shows progress, instructions, persona/filter context, the rating grid for
    the "gemini" and "llama" rows and a comment box. It then builds a
    ``Response``, stores it in session state and renders the navigation
    buttons.

    If ``data`` has no row at the current index, "Survey completed!" is printed
    and nothing is rendered.

    Args:
        data: Table of survey items, indexed positionally via ``data.iloc``.
    """
    current_index = st.session_state.current_index
    # Guard only the row lookup, so IndexErrors raised while rendering or
    # updating session state are not misreported as survey completion.
    try:
        config = data.iloc[current_index]
    except IndexError:
        print("Survey completed!")
        return

    st.markdown(f"## Hello {st.session_state.username.title()} 👋")
    # Progress bar
    progress = (current_index + 1) / len(data)
    st.progress(progress)
    st.write(f"Question {current_index + 1} of {len(data)}")

    st.markdown("### Instructions")
    instructions_html = load_html("static/instructions.html")
    with st.expander("Instructions", expanded=False):
        st.html(instructions_html)

    # Context information
    st.markdown("### Context Information")
    with st.expander("Persona", expanded=True):
        st.write(config["persona"])
    with st.expander("Filters", expanded=True):
        st.code(config["filters"], language="json")

    st.markdown("### Rate the following queries based on the above context.")
    g_ratings = display_ratings_row("gemini", config, current_index)
    l_ratings = display_ratings_row("llama", config, current_index)

    comment = st.text_area("Additional Comments (Optional):")

    response = Response(
        config_id=config["config_id"],
        model_ratings={
            "gemini": _to_model_ratings(g_ratings),
            "llama": _to_model_ratings(l_ratings),
        },
        comment=comment,
        timestamp=datetime.now().isoformat(),
    )
    _save_response(current_index, response)

    navigation_buttons(data, response)