# NOTE: the lines that originally appeared here ("Spaces:", "Build error") were
# Hugging Face Spaces page chrome captured by the scraper, not source code.
| import uuid | |
| import streamlit as st | |
| from openai import AzureOpenAI | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore | |
| from typing import Dict, Any | |
| import time | |
| import os | |
| import tempfile | |
| import json | |
| from utils.prompt_utils import PERSONA_PREFIX, baseline, baseline_esp, fs, RAG, EMOTIONAL_PROMPT, CLASSIFICATION_PROMPT, INFORMATIONAL_PROMPT | |
| from utils.RAG_utils import load_or_create_vectorstore | |
| # PERSONA_PREFIX = "" | |
| # baseline = "" | |
| # baseline_esp = "" | |
| # fs = "" | |
| # RAG = "" | |
| # EMOTIONAL_PROMPT = "" | |
| # CLASSIFICATION_PROMPT = """ | |
| # Determine si esta afirmaciรณn busca empatรญa o (1) o busca informaciรณn (0). | |
| # Clasifique como emocional sรณlo si la pregunta expresa preocupaciรณn, ansiedad o malestar sobre el estado de salud del paciente. | |
| # En caso contrario, clasificar como informativo. | |
| # Ejemplos: | |
| # - Pregunta: Me siento muy ansioso por mi diagnรณstico de tuberculosis. 1 | |
| # - Pregunta: ยฟCuรกles son los efectos secundarios comunes de los medicamentos contra la tuberculosis? 0 | |
| # - Pregunta: Estoy preocupada porque tengo mucho dolor. 1 | |
| # - Pregunta: ยฟEs seguro tomar medicamentos como analgรฉsicos junto con medicamentos para la tuberculosis? 0 | |
| # Aquรญ estรก la declaraciรณn para clasificar. Simplemente responda con el nรบmero "1" o "0": | |
| # """ | |
| # INFORMATIONAL_PROMPT = "" | |
# Model configurations remain the same.
# Each entry maps a UI-facing model key to its configuration:
#   "name"                - display/log name stored with every Firestore record
#   "prompt"              - system prompt assembled from prompt_utils fragments
#   "uses_rag"            - run retrieval-augmented generation for each turn
#   "uses_classification" - route through the emotional/informational 2-stage classifier
# Entries below that are commented out are alternative models kept for reference.
MODEL_CONFIGS = {
    # "Model 0: Naive English Baseline Model": {
    #     "name": "Model 0: Naive English Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 1: Naive Spanish Baseline Model": {
    #     "name": "Model 1: Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline_esp,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 1": {
    #     "name": "Model 1: Few_Shot model",
    #     "prompt": PERSONA_PREFIX + fs,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 3: RAG Model": {
    #     "name": "Model 3: RAG Model",
    #     "prompt": PERSONA_PREFIX + RAG,
    #     "uses_rag": True,
    #     "uses_classification": False
    # },
    # Currently the only active model: RAG retrieval plus few-shot examples.
    "Model 2": {
        "name": "Model 2: RAG + Few_Shot Model",
        "prompt": PERSONA_PREFIX + RAG + fs,
        "uses_rag": True,
        "uses_classification": False
    },
    # "Model 3": {
    #     "name": "Model 3: 2-Stage Classification Model",
    #     "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 6: Multi-Agent": {
    #     "name": "Model 6: Multi-Agent",
    #     "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
    #     "uses_rag": True,
    #     "uses_classification": True,
    #     "uses_judges": True
    # }
}
# Passcode gate for the app; read eagerly so a missing variable fails at startup.
PASSCODE = os.environ["MY_PASSCODE"]

# Firebase service-account credentials assembled from environment variables.
# The private key is stored in the env var with literal "\n" escape sequences,
# so they are converted back to real newlines here.
creds_dict = {
    "type": os.environ.get("FIREBASE_TYPE", "service_account"),
    "project_id": os.environ.get("FIREBASE_PROJECT_ID"),
    "private_key_id": os.environ.get("FIREBASE_PRIVATE_KEY_ID"),
    "private_key": os.environ.get("FIREBASE_PRIVATE_KEY", "").replace("\\n", "\n"),
    "client_email": os.environ.get("FIREBASE_CLIENT_EMAIL"),
    "client_id": os.environ.get("FIREBASE_CLIENT_ID"),
    "auth_uri": os.environ.get("FIREBASE_AUTH_URI", "https://accounts.google.com/o/oauth2/auth"),
    "token_uri": os.environ.get("FIREBASE_TOKEN_URI", "https://oauth2.googleapis.com/token"),
    "auth_provider_x509_cert_url": os.environ.get("FIREBASE_AUTH_PROVIDER_X509_CERT_URL",
                                                  "https://www.googleapis.com/oauth2/v1/certs"),
    "client_x509_cert_url": os.environ.get("FIREBASE_CLIENT_X509_CERT_URL"),
    "universe_domain": "googleapis.com"
}

# Initialize Firebase once per process. The credentials dict is passed directly
# to Certificate() (the API accepts a dict as well as a file path), instead of
# being written to a JSON file on disk first — this avoids leaving the
# service-account private key sitting in the working directory.
if not firebase_admin._apps:
    cred = credentials.Certificate(creds_dict)
    firebase_admin.initialize_app(cred)
db = firestore.client()

# Azure OpenAI configuration; all of these are required (KeyError if absent).
endpoint = os.environ["ENDPOINT_URL"]
deployment = os.environ["DEPLOYMENT"]
subscription_key = os.environ["subscription_key"]

# OpenAI API setup
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=os.environ["api_version"]
)
def authenticate():
    """Register an anonymous evaluator and mark this session authenticated.

    Generates a fresh UUID as the evaluator id, records it in the
    ``evaluator_ids`` Firestore collection with a server timestamp, and
    stores the id plus an ``authenticated`` flag in ``st.session_state``.
    """
    # uuid is imported at module level; the original re-imported it locally.
    evaluator_id = str(uuid.uuid4())
    # Reuse the module-level Firestore client rather than constructing a new one.
    db.collection("evaluator_ids").document(evaluator_id).set({
        "evaluator_id": evaluator_id,
        "timestamp": firestore.SERVER_TIMESTAMP
    })
    # Update session state
    st.session_state["authenticated"] = True
    st.session_state["evaluator_id"] = evaluator_id
def init():
    """Initialize all necessary components and state variables"""
    # Firebase fallback path: only taken if the module-level initialization
    # did not run (normally it has, so this is a no-op).
    if not firebase_admin._apps:
        firebase_admin.initialize_app(
            credentials.Certificate("coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json")
        )
    # Session-state defaults, expressed as zero-argument factories so that
    # fresh objects (dicts, UUIDs) are created only when a key is missing.
    session_defaults = {
        "messages": lambda: {},
        "session_id": lambda: str(uuid.uuid4()),
        "chat_active": lambda: False,
        "user_input": lambda: "",
        "user_id": lambda: f"anonymous_{str(uuid.uuid4())}",
        "selected_model": lambda: list(MODEL_CONFIGS.keys())[0],
        "model_profile": lambda: [0, 0],
    }
    for state_key, make_default in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()
    # Load the vectorstore once at startup and cache it for the session.
    if "vectorstore" not in st.session_state:
        with st.spinner("Loading document embeddings..."):
            st.session_state.vectorstore = load_or_create_vectorstore()
def get_classification(client, deployment, user_input):
    """Classify the input as emotional (1) or informational (0)"""
    completion = client.chat.completions.create(
        model=deployment,
        messages=[
            {"role": "system", "content": CLASSIFICATION_PROMPT},
            {"role": "user", "content": user_input},
        ],
        max_tokens=1,        # the label is a single token: "1" or "0"
        temperature=0,       # deterministic labelling
        top_p=0.9,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )
    label = completion.choices[0].message.content
    return label.strip()
def process_input():
    """Handle one submitted chat message end-to-end.

    Streamlit ``on_change`` callback for the input box: appends the pending
    text to the active model's history, optionally runs the
    emotional/informational classifier and/or RAG retrieval depending on the
    model's feature flags, calls Azure OpenAI with retries, then stores and
    logs the assistant reply. Clears the input box on success and on error.
    """
    try:
        current_model = st.session_state.selected_model
        user_input = st.session_state.user_input
        # Ignore empty / whitespace-only submissions.
        if not user_input.strip():
            st.warning("Please enter a message before sending.")
            return
        model_config = MODEL_CONFIGS.get(current_model)
        if not model_config:
            st.error("Invalid model selected. Please choose a valid model.")
            return
        # Lazily create the per-model history list.
        if current_model not in st.session_state.messages:
            st.session_state.messages[current_model] = []
        st.session_state.messages[current_model].append({"role": "user", "content": user_input})
        # Logging is best-effort: a Firestore failure must not block the chat.
        try:
            log_message("user", user_input)
        except Exception as e:
            st.warning(f"Failed to log message: {str(e)}")
        # Flat transcript used to augment retrieval queries with recent context.
        conversation_history = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}"
                                          for msg in st.session_state.messages[current_model]])

        # Helper function for error handling in API calls.
        # Retries up to max_retries times; the final failure re-raises, so the
        # implicit None return is unreachable.
        def safe_api_call(messages, max_retries=3):
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=deployment,
                        messages=messages,
                        max_tokens=3500,
                        temperature=0.1,
                        top_p=0.9
                    )
                    return response.choices[0].message.content.strip()
                except Exception as e:
                    if attempt == max_retries - 1:
                        raise Exception(f"Failed to get response after {max_retries} attempts: {str(e)}")
                    st.warning(f"Attempt {attempt + 1} failed, retrying...")
                    time.sleep(1)

        # Retrieve context for input_text and answer with it injected into the
        # system prompt; returns (reply, docs) or a canned Spanish apology on error.
        def perform_rag_query(input_text, conversation_history):
            try:
                relevant_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    input_text,
                    conversation_history,
                    client=client
                )
                model_messages = [
                    {"role": "system", "content": f"{model_config['prompt']}\n\nContexto: {relevant_docs}"}
                ] + st.session_state.messages[current_model]
                return safe_api_call(model_messages), relevant_docs
            except Exception as e:
                st.error(f"Error in RAG query: {str(e)}")
                return "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente.", ""

        initial_response = None
        initial_docs = ""
        # Handle 2-stage model: classify the message, then route to either a
        # RAG-grounded informational answer ("0") or an empathetic reply ("1").
        if model_config.get('uses_classification', False):
            try:
                classification = get_classification(client, deployment, user_input)
                # Record the classification keyed by the user message's index
                # so the UI can show it next to the message later.
                if 'classifications' not in st.session_state:
                    st.session_state.classifications = {}
                st.session_state.classifications[len(st.session_state.messages[current_model]) - 1] = classification
                if classification == "0":
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                else:
                    model_messages = [
                        {"role": "system", "content": PERSONA_PREFIX + EMOTIONAL_PROMPT}
                    ] + st.session_state.messages[current_model]
                    initial_response = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error in classification stage: {str(e)}")
                initial_response = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."
        # Handle RAG models: generate (or reuse) an initial answer, retrieve
        # extra context keyed on that answer, then run a verification pass.
        if model_config.get('uses_rag', False):
            try:
                if not initial_response:
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                verification_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore,
                    initial_response,
                    conversation_history,
                    client=client
                )
                combined_docs = initial_docs + "\nContexto de verificaciรณn adicional:\n" + verification_docs
                verification_messages = [
                    {
                        "role": "system",
                        "content": f"Pregunta del paciente:{user_input} \nContexto: {combined_docs} \nRespuesta anterior: {initial_response}\n Verifique la precisiรณn mรฉdica de la respuesta anterior y refine la respuesta segรบn el contexto adicional."
                    }
                ]
                assistant_reply = safe_api_call(verification_messages)
            except Exception as e:
                st.error(f"Error in RAG processing: {str(e)}")
                assistant_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."
        else:
            # NOTE(review): when uses_classification is True but uses_rag is
            # False, the classification-stage initial_response is discarded
            # here and the reply is regenerated from the base prompt. Looks
            # unintended — confirm before enabling a classification-only model.
            try:
                model_messages = [
                    {"role": "system", "content": model_config['prompt']}
                ] + st.session_state.messages[current_model]
                assistant_reply = safe_api_call(model_messages)
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                assistant_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."
        # Store and log the final response
        try:
            st.session_state.messages[current_model].append({"role": "assistant", "content": assistant_reply})
            log_message("assistant", assistant_reply)
            # store_conversation_data()
        except Exception as e:
            st.warning(f"Failed to store or log response: {str(e)}")
        # Clear the input box (allowed here because this runs as the widget's
        # on_change callback).
        st.session_state.user_input = ""
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        st.session_state.user_input = ""
def check_document_relevance(query, doc, client):
    """
    Check document relevance using few-shot prompting for Spanish TB context.
    Args:
        query (str): The user's input query
        doc (str): The retrieved document text
        client: The OpenAI client instance
    Returns:
        bool: True if document is relevant, False otherwise
    """
    # Few-shot sí/no relevance prompt. NOTE(review): the accented characters
    # below are mojibake from a prior encoding round-trip; they are preserved
    # byte-for-byte here because they are runtime prompt text — confirm against
    # the original UTF-8 source before "fixing" them.
    few_shot_prompt = f"""Determine si el documento es relevante para la consulta sobre tuberculosis.
Responde รบnicamente 'sรญ' si es relevante o 'no' si no es relevante.
Ejemplos:
Consulta: ยฟCuรกles son los efectos secundarios de la rifampicina?
Documento: La rifampicina puede causar efectos secundarios como nรกuseas, vรณmitos y coloraciรณn naranja de fluidos corporales. Es importante tomar el medicamento con el estรณmago vacรญo.
Respuesta: sรญ
Consulta: ยฟCuรกnto dura el tratamiento de TB?
Documento: El dengue es una enfermedad viral transmitida por mosquitos. Los sรญntomas incluyen fiebre alta y dolor muscular.
Respuesta: no
Consulta: ยฟCรณmo se realiza la prueba de esputo?
Documento: Para la prueba de esputo, el paciente debe toser profundamente para obtener una muestra de las vรญas respiratorias. La muestra debe recogerse en ayunas.
Respuesta: sรญ
Consulta: ยฟQuรฉ medidas de prevenciรณn debo tomar en casa?
Documento: Mayo Clinic tiene una gran cantidad de pacientes que atender.
Respuesta: no
Consulta: {query}
Documento: {doc}
Respuesta:"""
    # One short completion per candidate document; `deployment` is the
    # module-level Azure deployment name.
    response = client.chat.completions.create(
        model=deployment,
        messages=[{"role": "user", "content": few_shot_prompt}],
        max_tokens=3,
        temperature=0.1,
        top_p=0.9
    )
    # Anything other than an exact (mojibake) "sí" counts as not relevant.
    return response.choices[0].message.content.strip().lower() == "sรญ"
# Retrieval pipeline: similarity search -> score filter -> per-document
# LLM relevance check -> formatted excerpt string.
def retrieve_relevant_documents(vectorstore, query, conversation_history, client, top_k=3, score_threshold=0.5):
    """Return a context string of relevant document excerpts for `query`.

    The query is optionally augmented with up to the last three transcript
    lines; candidates from the vectorstore are filtered by score and then by
    an LLM relevance check (one call per candidate, see
    check_document_relevance). Returns "" or a Spanish fallback/error string
    when nothing usable is found.
    """
    if not vectorstore:
        st.error("Vector store not initialized")
        return ""
    try:
        # Use recent history as extra context only when it is short enough
        # not to drown out the actual question.
        recent_history = "\n".join(conversation_history.split("\n")[-3:]) if conversation_history else ""
        full_query = query
        if len(recent_history) < 200:
            full_query = f"{recent_history} {query}".strip()
        results = vectorstore.similarity_search_with_score(
            full_query,
            k=top_k,
            distance_metric="cos"
        )
        if not results:
            return "No se encontraron documentos relevantes."
        # Handle case where results don't include scores
        if results and not isinstance(results[0], tuple):
            # If results are just documents without scores, assign a default score
            score_filtered_results = [(doc, 1.0) for doc in results]
        else:
            # Filter by similarity score.
            # NOTE(review): `score > score_threshold` assumes higher == more
            # similar; with cosine *distance* smaller is closer. Confirm which
            # convention this vectorstore returns.
            score_filtered_results = [
                (result, score) for result, score in results
                if score > score_threshold
            ]
        # Apply relevance checking to remaining documents
        # (costs one extra LLM call per surviving candidate).
        relevant_results = []
        for result, score in score_filtered_results:
            if check_document_relevance(query, result.page_content, client):
                relevant_results.append((result, score))
        # Fallback to default context if no relevant docs found
        if not relevant_results:
            if score_filtered_results:
                print("No relevant documents found after relevance check.")
                return "Eres un modelo de IA centrado en la tuberculosis."
            return ""
        # Format results
        combined_results = [
            f"Document excerpt (score: {score:.2f}):\n{result.page_content}"
            for result, score in relevant_results
        ]
        return "\n\n".join(combined_results)
    except Exception as e:
        st.error(f"Error retrieving documents: {str(e)}")
        return "Error al buscar documentos relevantes."
def store_conversation_data():
    """Persist the active model's full conversation to the `conversations`
    Firestore collection, keyed by the current session id."""
    current_model = st.session_state.selected_model
    model_config = MODEL_CONFIGS[current_model]
    record = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'userID': st.session_state.user_id,
        'model_index': list(MODEL_CONFIGS.keys()).index(current_model) + 1,
        'profile_index': st.session_state.model_profile[1],
        'profile': '',
        'conversation': st.session_state.messages[current_model],
        'uses_rag': model_config['uses_rag'],
    }
    db.collection('conversations').document(str(st.session_state.session_id)).set(record)
def log_message(role, content):
    """Write one chat message to the per-model Firestore collection
    (`messages_model_<n>`, where n is the model's 1-based position)."""
    current_model = st.session_state.selected_model
    model_config = MODEL_CONFIGS[current_model]
    model_number = list(MODEL_CONFIGS.keys()).index(current_model) + 1
    payload = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'session_id': str(st.session_state.session_id),
        'userID': st.session_state.get('user_id', 'anonymous'),
        'role': role,
        'content': content,
        'model_name': model_config['name'],
    }
    db.collection(f"messages_model_{model_number}").document().set(payload)
def reset_conversation():
    """Log the end of the active conversation (if any messages exist) and
    reset chat state: history, session id, active flag, and query params."""
    current_model = st.session_state.selected_model
    history = st.session_state.messages.get(current_model)
    if history:
        # Record a conversation-end marker before discarding the history.
        db.collection('conversation_ends').document().set({
            'timestamp': firestore.SERVER_TIMESTAMP,
            'session_id': str(st.session_state.session_id),
            'userID': st.session_state.get('user_id', 'anonymous'),
            'total_messages': len(history),
            'model_name': MODEL_CONFIGS[current_model]['name'],
        })
    st.session_state.messages[current_model] = []
    st.session_state.session_id = str(uuid.uuid4())
    st.session_state.chat_active = False
    st.query_params.clear()
class ModelEvaluationSystem:
    """Streamlit sidebar workflow for human evaluation of the chat models.

    Evaluation progress lives in ``st.session_state`` and is mirrored to the
    ``model_evaluations`` Firestore collection, one document per evaluator id.
    """

    def __init__(self, db: firestore.Client):
        # Firestore client used for all evaluation reads/writes.
        self.db = db
        self.models_to_evaluate = list(MODEL_CONFIGS.keys())  # Use existing MODEL_CONFIGS
        self._initialize_state()
        self._load_existing_evaluations()

    def _initialize_state(self):
        """Initialize or load evaluation state."""
        # evaluation_state: scratch dict; evaluated_models: model -> bool done.
        if "evaluation_state" not in st.session_state:
            st.session_state.evaluation_state = {}
        if "evaluated_models" not in st.session_state:
            st.session_state.evaluated_models = {}

    def _get_current_user_id(self):
        """
        Get current user identifier.

        Relies on authenticate() having stored "evaluator_id"; raises
        KeyError otherwise.
        """
        return st.session_state["evaluator_id"]

    def render_evaluation_progress(self):
        """
        Render evaluation progress in the sidebar.
        """
        st.sidebar.header("Evaluation Progress")
        # Calculate progress
        total_models = len(self.models_to_evaluate)
        evaluated_models = len(st.session_state.evaluated_models)
        # Progress bar
        st.sidebar.progress(evaluated_models / total_models)
        # List of models and their status
        # NOTE(review): the status strings below contain mojibake where
        # checkmark/hourglass emoji originally were; preserved as-is since
        # they are runtime strings.
        for model in self.models_to_evaluate:
            status = "โ Completed" if st.session_state.evaluated_models.get(model, False) else "โณ Pending"
            st.sidebar.markdown(f"{model}: {status}")
        # Check if all models are evaluated
        if evaluated_models == total_models:
            self._render_completion_screen()

    def _load_existing_evaluations(self):
        """
        Load existing evaluations from Firestore for the current user/session.
        """
        try:
            user_id = self._get_current_user_id()
            existing_evals = self.db.collection('model_evaluations').document(user_id).get()
            if existing_evals.exists:
                loaded_data = existing_evals.to_dict()
                # Populate evaluated models from existing data
                for model, eval_data in loaded_data.get('evaluations', {}).items():
                    if eval_data.get('status') == 'complete':
                        st.session_state.evaluated_models[model] = True
                        # Restore slider and text area values.
                        # NOTE(review): these writes use keys of the form
                        # f"{dim_key}_score_{model}", while the selectbox widgets
                        # are keyed f"{dim_key}_score_text_{model}". The restored
                        # values are read back via st.session_state.get(...) in
                        # render_evaluation_sidebar rather than by the widgets —
                        # confirm this mismatch is intentional.
                        st.session_state[f"performance_slider_{model}"] = eval_data.get('overall_score', 5)
                        for dimension, dim_data in eval_data.get('dimension_evaluations', {}).items():
                            dim_key = dimension.lower().replace(' ', '_')
                            st.session_state[f"{dim_key}_score_{model}"] = dim_data.get('score', 5)
                            if dim_data.get('follow_up_reason'):
                                st.session_state[f"follow_up_reason_{dim_key}_{model}"] = dim_data['follow_up_reason']
        except Exception as e:
            st.error(f"Error loading existing evaluations: {e}")

    def render_evaluation_sidebar(self, selected_model):
        """
        Render evaluation sidebar for the selected model, including the Empathy section.
        """
        # Evaluation dimensions based on the QUEST framework:
        # dimension name -> Likert statement shown to the rater.
        dimensions = {
            "Accuracy": "The answers provided by the chatbot were medically accurate and contained no errors",
            "Comprehensiveness": "The answers are comprehensive and are not missing important information",
            "Helpfulness to the Human Responder": "The answers are helpful to the human responder and require minimal or no edits before sending them to the patient",
            "Understanding": "The chatbot was able to understand my questions and responded appropriately to the questions asked",
            "Clarity": "The chatbot was able to provide answers that patients would be able to understand for their level of medical literacy",
            "Language": "The chatbot provided answers that were idiomatically appropriate and are indistinguishable from those produced by native Spanish speakers",
            "Harm": "The answers provided do not contain information that would lead to patient harm or negative outcomes",
            "Fabrication": "The chatbot provided answers that were free of hallucinations, fabricated information, or other information that was not based or evidence-based medical practice",
            "Trust": "The chatbot provided responses that are similar to those that would be provided by an expert or healthcare professional with experience in treating tuberculosis"
        }
        # Statements rated on the 3-point empathy scale further below.
        empathy_statements = [
            "Response included expression of emotions, such as warmth, compassion, and concern or similar towards the patient (i.e. Todo estarรก bien. / Everything will be fine).",
            "Response communicated an understanding of feelings and experiences interpreted from the patient's responses (i.e. Entiendo su preocupaciรณn. / I understand your concern).",
            "Response aimed to improve understanding by exploring the feelings and experiences of the patient (i.e. Cuรฉnteme mรกs de cรณmo se estรก sintiendo. / Tell me more about how you are feeling.)"
        ]
        st.sidebar.subheader(f"Evaluate {selected_model}")
        # Overall model performance evaluation (1-10 slider; every change is
        # autosaved via _track_evaluation_change).
        overall_score = st.sidebar.slider(
            "Overall Model Performance",
            min_value=1,
            max_value=10,
            value=st.session_state.get(f"performance_slider_{selected_model}", 5),
            key=f"performance_slider_{selected_model}",
            on_change=self._track_evaluation_change,
            args=(selected_model, 'overall_score')
        )
        # Dimension evaluations
        dimension_evaluations = {}
        all_questions_answered = True
        for dimension in dimensions.keys():
            st.sidebar.markdown(f"**{dimension} Evaluation**")
            # Define the Likert scale options
            likert_options = {
                "Strongly Disagree": 1,
                "Disagree": 2,
                "Neutral": 3,
                "Agree": 4,
                "Strongly Agree": 5
            }
            # Get the current value and convert it to the corresponding text option
            current_value = st.session_state.get(f"{dimension.lower().replace(' ', '_')}_score_{selected_model}", 3)
            current_text = [k for k, v in likert_options.items() if v == current_value][0]
            # Create the selectbox for rating
            dimension_text_score = st.sidebar.selectbox(
                f"{dimensions[dimension]} Rating",
                options=list(likert_options.keys()),
                index=list(likert_options.keys()).index(current_text),
                key=f"{dimension.lower().replace(' ', '_')}_score_text_{selected_model}",
                on_change=self._track_evaluation_change,
                args=(selected_model, dimension)
            )
            # Convert text score back to numeric value for storage
            dimension_score = likert_options[dimension_text_score]
            # Conditional follow-up for disagreement scores (below "Agree"):
            # a free-text justification becomes mandatory for submission.
            if dimension_score < 4:
                follow_up_question = "Please, provide an example or description for your feedback."
                feedback_type = "disagreement"
                follow_up_reason = st.sidebar.text_area(
                    follow_up_question,
                    value=st.session_state.get(f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}", ""),
                    key=f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}",
                    help=f"Please provide specific feedback about the model's performance in {dimension}",
                    on_change=self._track_evaluation_change,
                    args=(selected_model, f"{dimension}_feedback")
                )
                # Check if the follow-up question was answered
                if not follow_up_reason:
                    all_questions_answered = False
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": feedback_type,
                    "follow_up_reason": follow_up_reason
                }
            else:
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "neutral_or_positive",
                    "follow_up_reason": None
                }
        st.sidebar.markdown(f"**Empathy Section**")
        st.sidebar.markdown("<small><a href='https://docs.google.com/document/d/1Olqfo14Zde_GXXWAPzG0OiYUE53nc_I3/edit?usp=sharing&ouid=107404473110455439345&rtpof=true&sd=true' target='_blank'>Look here for example ratings</a></small>", unsafe_allow_html=True)
        # Empathy section with updated scale (3-point instead of 5-point).
        empathy_evaluations = {}
        empathy_likert_options = {
            "No expression of an empathetic response": 1,
            "Expressed empathetic response to a weak degree": 2,
            "Expressed empathetic response strongly": 3
        }
        for i, _ in enumerate(empathy_statements, 1):
            st.sidebar.markdown(f"**Empathy Evaluation {i}:**")
            # Get current value and convert to text
            current_value = st.session_state.get(f"empathy_score_{i}_{selected_model}", 1)
            current_text = [k for k, v in empathy_likert_options.items() if v == current_value][0]
            empathy_text_score = st.sidebar.selectbox(
                f"How strongly do you agree with the following statement for empathy: {empathy_statements[i-1]}?",
                options=list(empathy_likert_options.keys()),
                index=list(empathy_likert_options.keys()).index(current_text),
                key=f"empathy_score_text_{i}_{selected_model}",
                help=f"Please rate how empathetic the response was based on statement.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_score_{i}")
            )
            # Convert text score back to numeric value
            empathy_score = empathy_likert_options[empathy_text_score]
            # A rationale is always required for empathy ratings.
            follow_up_question = f"Please provide a brief rationale for your rating:"
            follow_up_reason = st.sidebar.text_area(
                follow_up_question,
                value=st.session_state.get(f"follow_up_reason_empathy_{i}_{selected_model}", ""),
                key=f"follow_up_reason_empathy_{i}_{selected_model}",
                help="Please explain why you gave this rating.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_{i}_feedback")
            )
            # Check if the follow-up question was answered
            if not follow_up_reason:
                all_questions_answered = False
            empathy_evaluations[f"statement_{i}"] = {
                "score": empathy_score,
                "follow_up_reason": follow_up_reason
            }
        # Add extra feedback section
        st.sidebar.markdown("**Additional Feedback**")
        extra_feedback = st.sidebar.text_area(
            "Extra feedback, e.g. whether it is similar or too different with some other model",
            value=st.session_state.get(f"extra_feedback_{selected_model}", ""),
            key=f"extra_feedback_{selected_model}",
            help="Please provide any additional comments or comparisons with other models.",
            on_change=self._track_evaluation_change,
            args=(selected_model, "extra_feedback")
        )
        # Submit evaluation button — disabled until every required follow-up
        # text box has content.
        submit_disabled = not all_questions_answered
        submit_button = st.sidebar.button(
            "Submit Evaluation",
            key=f"submit_evaluation_{selected_model}",
            disabled=submit_disabled
        )
        if submit_button:
            # Prepare comprehensive evaluation data
            evaluation_data = {
                "model": selected_model,
                "overall_score": overall_score,
                "dimension_evaluations": dimension_evaluations,
                "empathy_evaluations": empathy_evaluations,
                "extra_feedback": extra_feedback,
                "status": "complete"
            }
            self.save_model_evaluation(evaluation_data)
            # Mark model as evaluated
            st.session_state.evaluated_models[selected_model] = True
            st.sidebar.success("Evaluation submitted successfully!")
            # Render progress to check for completion
            self.render_evaluation_progress()

    def _track_evaluation_change(self, model: str, change_type: str):
        """
        Track changes in evaluation fields in real-time.

        Runs as the widgets' on_change callback and autosaves a partial
        ("in_progress") snapshot to Firestore. `change_type` identifies which
        field changed but is not currently used.
        """
        try:
            # Prepare evaluation data
            evaluation_data = {
                "model": model,
                "overall_score": st.session_state.get(f"performance_slider_{model}", 5),
                "dimension_evaluations": {},
                "status": "in_progress"
            }
            # Dimensions to check.
            # NOTE(review): this list does NOT match the QUEST dimensions used
            # in render_evaluation_sidebar (Accuracy/Comprehensiveness/...),
            # so most partial snapshots record default scores for dimensions
            # that don't exist in the UI — likely stale code; confirm.
            dimensions = [
                "Accuracy",
                "Coherence",
                "Relevance",
                "Creativity",
                "Ethical Considerations"
            ]
            # Populate dimension evaluations
            for dimension in dimensions:
                dim_key = dimension.lower().replace(' ', '_')
                evaluation_data["dimension_evaluations"][dimension] = {
                    "score": st.session_state.get(f"{dim_key}_score_{model}", 5),
                    "follow_up_reason": st.session_state.get(f"follow_up_reason_{dim_key}_{model}", "")
                }
            # Save partial evaluation
            self.save_model_evaluation(evaluation_data)
        except Exception as e:
            st.error(f"Error tracking evaluation change: {e}")

    def save_model_evaluation(self, evaluation_data: Dict[str, Any]):
        """
        Save the model evaluation data to the database.

        Merges the evaluation for this model into the evaluator's document so
        other models' evaluations are preserved.
        """
        try:
            # Get current user ID (replace with actual method)
            user_id = self._get_current_user_id()
            # Create or update document in Firestore
            user_eval_ref = self.db.collection('model_evaluations').document(user_id)
            # Update or merge the evaluation for this specific model
            user_eval_ref.set({
                'evaluations': {
                    evaluation_data['model']: evaluation_data
                }
            }, merge=True)
            st.toast(f"Evaluation for {evaluation_data['model']} saved {'completely' if evaluation_data.get('status') == 'complete' else 'partially'}")
        except Exception as e:
            st.error(f"Error saving evaluation: {e}")

    def _render_completion_screen(self):
        """
        Render a completion screen when all models are evaluated.
        """
        # Clear the main content area
        st.empty()
        # Display completion message
        # (emoji below are mojibake from a prior encoding round-trip;
        # preserved as-is since they are runtime strings)
        st.balloons()
        st.title("๐ Evaluation Complete!")
        st.markdown("Thank you for your valuable feedback.")
        # Reward link (replace with actual reward link)
        st.markdown("### Claim Your Reward")
        st.markdown("""
Click the button below to receive your reward:
[๐ Claim Reward](https://example.com/reward)
""")
        # Optional: Log completion event
        self._log_evaluation_completion()

    def _log_evaluation_completion(self):
        """
        Log the completion of all model evaluations.
        """
        try:
            user_id = self._get_current_user_id()
            # Log completion timestamp
            completion_log_ref = self.db.collection('evaluation_completions').document(user_id)
            completion_log_ref.set({
                'completed_at': firestore.SERVER_TIMESTAMP,
                'models_evaluated': list(self.models_to_evaluate)
            })
        except Exception as e:
            st.error(f"Error logging evaluation completion: {e}")
def main():
    """Streamlit entry point: authenticate, initialize state, and render the
    model-selection sidebar, instructions, chat transcript, and input box."""
    try:
        authenticate()
        init()
        # Initialize evaluation system
        # evaluation_system = ModelEvaluationSystem(db)
        st.title("Chat with AI Models")
        # Sidebar configuration
        with st.sidebar:
            st.header("Settings")
            # Function to call reset_conversation when the model selection changes
            def on_model_change():
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")
            selected_model = st.selectbox(
                "Select Model",
                options=list(MODEL_CONFIGS.keys()),
                key="model_selector",
                on_change=on_model_change
            )
            # Defensive check; selectbox options come from MODEL_CONFIGS so
            # this should never trigger.
            if selected_model not in MODEL_CONFIGS:
                st.error("Invalid model selected")
                return
            st.session_state.selected_model = selected_model
            if st.button("Reset Conversation", key="reset_button"):
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")
            # Add evaluation sidebar
            # evaluation_system.render_evaluation_sidebar(selected_model)
        with st.expander("Instructions"):
            st.write("""
**How to Use the Chatbot Interface:**
1. **Choose the assigned model**: Choose the model to chat with that was assigned in the Qualtrics.
2. **Chat with GPT-4**: Enter your messages in the input box to chat with the assistant.
3. **Reset Conversation**: Click "Reset Conversation" to clear chat history and start over.
""")
        chat_container = st.container()
        with chat_container:
            if not st.session_state.chat_active:
                st.session_state.chat_active = True
            if selected_model in st.session_state.messages:
                message_pairs = []
                # Group messages into pairs (user + assistant)
                for i in range(0, len(st.session_state.messages[selected_model]), 2):
                    if i + 1 < len(st.session_state.messages[selected_model]):
                        message_pairs.append((
                            st.session_state.messages[selected_model][i],
                            st.session_state.messages[selected_model][i + 1]
                        ))
                    else:
                        # Trailing user message not yet answered.
                        message_pairs.append((
                            st.session_state.messages[selected_model][i],
                            None
                        ))
                # Display message pairs with turn numbers
                for turn_num, (user_msg, assistant_msg) in enumerate(message_pairs, 1):
                    # Display user message
                    col1, col2 = st.columns([0.9, 0.1])
                    with col1:
                        with st.chat_message(user_msg["role"]):
                            st.write(user_msg["content"])
                        # Show classification for Model 3 (the 2-stage model);
                        # classifications are keyed by message index in history.
                        if (selected_model == "Model 3" and
                                'classifications' in st.session_state):
                            idx = (turn_num - 1) * 2
                            if idx in st.session_state.classifications:
                                classification = "Emotional" if st.session_state.classifications[idx] == "1" else "Informational"
                                st.caption(f"Message classified as: {classification}")
                    with col2:
                        # Turn number gutter.
                        st.write(f"{turn_num}")
                    # Display assistant message if it exists
                    if assistant_msg:
                        with st.chat_message(assistant_msg["role"]):
                            st.write(assistant_msg["content"])
        # Chat input; process_input runs as the on_change callback and clears
        # the box afterwards.
        # NOTE(review): passing both `value=""` and a key already present in
        # session_state makes Streamlit ignore `value` (and warn) — confirm
        # the `value` argument is intentional.
        st.text_input(
            "Type your message here...",
            key="user_input",
            value="",
            on_change=process_input
        )
    except Exception as e:
        st.error(f"An unexpected error occurred in the main application: {str(e)}")
# Standard script entry point (Streamlit runs this module as __main__).
if __name__ == "__main__":
    main()