Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| from datasets import load_dataset | |
| import json | |
| import re | |
| from openai import OpenAI | |
| import os | |
| from config import DATASETS, MODELS | |
| import matplotlib.pyplot as plt | |
| import altair as alt | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type | |
| import threading | |
| from anthropic import Anthropic | |
| import google.generativeai as genai | |
| import hmac | |
| import hashlib | |
| load_dotenv() | |
def initialize_session_state():
    """Seed ``st.session_state`` with the keys the app relies on.

    Keys that already exist are left untouched, so only missing entries
    receive their default value.
    """
    defaults = (
        ('api_configured', False),
        ('togetherai_client', None),
        ('openai_client', None),
        ('anthropic_client', None),
    )
    for key, default in defaults:
        if key not in st.session_state:
            st.session_state[key] = default
def _constant_time_equal(supplied, expected):
    """Return True when ``supplied`` equals a non-empty ``expected`` string.

    An empty or unset ``expected`` never matches, so leaving the credential
    environment variables undefined cannot be satisfied by submitting blank
    fields. Both values are encoded to bytes because ``hmac.compare_digest``
    rejects non-ASCII ``str`` arguments with ``TypeError``.
    """
    if not expected:
        return False
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


def setup_api_clients():
    """Render the sidebar API configuration form and build the API clients.

    Two modes are offered:

    * "Use the stored API keys": the user supplies a username/password that
      must match ``STREAMLIT_USERNAME`` / ``STREAMLIT_PASSWORD``; the clients
      are then built from the ``*_API_KEY`` environment variables.
    * Otherwise the user types their own keys.

    On success the clients are stored in ``st.session_state`` and
    ``st.session_state.api_configured`` is set to True; on any failure it is
    set to False and an error is shown in the sidebar.
    """
    initialize_session_state()
    with st.sidebar:
        st.title("API Configuration")
        use_stored = st.checkbox("Use the stored API keys")
        if use_stored:
            username = st.text_input("Username")
            password = st.text_input("Password", type="password")
            if st.button("Verify Credentials"):
                # Evaluate both comparisons before combining them so the
                # password check runs regardless of the username outcome.
                username_ok = _constant_time_equal(
                    username, os.environ.get("STREAMLIT_USERNAME", "")
                )
                password_ok = _constant_time_equal(
                    password, os.environ.get("STREAMLIT_PASSWORD", "")
                )
                if username_ok and password_ok:
                    # Mirror the error handling of the "provided keys" branch:
                    # a missing environment variable should produce a sidebar
                    # error rather than an unhandled exception.
                    try:
                        st.session_state.togetherai_client = OpenAI(
                            api_key=os.getenv('TOGETHERAI_API_KEY'),
                            base_url="https://api.together.xyz/v1"
                        )
                        st.session_state.openai_client = OpenAI(
                            api_key=os.getenv('OPENAI_API_KEY')
                        )
                        st.session_state.anthropic_client = Anthropic(
                            api_key=os.getenv('ANTHROPIC_API_KEY')
                        )
                        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
                    except Exception as e:
                        st.error(f"Error initializing API clients: {str(e)}")
                        st.session_state.api_configured = False
                    else:
                        st.session_state.api_configured = True
                        st.success("Successfully configured the API clients with stored keys!")
                else:
                    st.error("Invalid credentials. Please try again or use your own API keys.")
                    st.session_state.api_configured = False
        else:
            st.subheader("Enter Your API Keys")
            togetherai_key = st.text_input("Together AI API Key", type="password", key="togetherai_key")
            openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key")
            anthropic_key = st.text_input("Anthropic API Key", type="password", key="anthropic_key")
            gemini_key = st.text_input("Gemini API Key", type="password", key="gemini_key")
            if st.button("Initialize with the provided keys"):
                try:
                    st.session_state.togetherai_client = OpenAI(
                        api_key=togetherai_key,
                        base_url="https://api.together.xyz/v1"
                    )
                    st.session_state.openai_client = OpenAI(
                        api_key=openai_key
                    )
                    st.session_state.anthropic_client = Anthropic(
                        api_key=anthropic_key
                    )
                    genai.configure(api_key=gemini_key)
                    st.session_state.api_configured = True
                    st.success("Successfully configured the API clients with provided keys!")
                except Exception as e:
                    st.error(f"Error initializing API clients: {str(e)}")
                    st.session_state.api_configured = False
# Upper bound on model API requests allowed in flight at the same time; also
# used as the worker count of the evaluation thread pool.
MAX_CONCURRENT_CALLS = 5

# Module-wide gate acquired around each model call in get_model_response.
semaphore = threading.Semaphore(value=MAX_CONCURRENT_CALLS)
def load_dataset_by_name(dataset_name, split="train"):
    """Load a configured multiple-choice dataset as a list of question dicts.

    Args:
        dataset_name: Key into ``DATASETS``; its ``"loader"`` entry is passed
            to ``load_dataset``.
        split: Name of the dataset split to read.

    Returns:
        A list of dicts with the keys ``question``, ``options``,
        ``correct_answer``, ``subject_name``, ``topic_name`` and
        ``explanation``. Only rows whose ``choice_type`` is ``'single'`` and
        whose ``cop`` value is a valid index into the four options are kept.
    """
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    df = df[df['choice_type'] == 'single']

    questions = []
    # Plain dict records avoid constructing a Series per row as iterrows() does.
    for row in df.to_dict("records"):
        options = [row['opa'], row['opb'], row['opc'], row['opd']]

        # A missing or out-of-range label must not be used as an index:
        # a negative value would silently wrap around to another option.
        try:
            answer_index = int(row['cop'])
        except (TypeError, ValueError):
            continue
        if not 0 <= answer_index < len(options):
            continue

        questions.append({
            'question': row['question'],
            'options': options,
            'correct_answer': options[answer_index],
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        })
    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions
def _query_provider(model_config, prompt, clients):
    """Send ``prompt`` to the provider named in ``model_config`` and return the reply text.

    Raises:
        ValueError: if ``model_config["provider"]` is not one of the
            supported providers.
    """
    provider = model_config["provider"]
    model_id = model_config["model_id"]

    if provider in ("togetherai", "openai"):
        # Both providers are reached through an OpenAI-style chat client
        # stored under the provider's own name in ``clients``.
        response = clients[provider].chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}]
        )
        # The message content may be None; treat that as an empty reply.
        return (response.choices[0].message.content or "").strip()

    if provider == "anthropic":
        response = clients["anthropic"].messages.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=4096
        )
        return response.content[0].text

    if provider == "google":
        model = genai.GenerativeModel(
            model_name=model_id
        )
        chat_session = model.start_chat(
            history=[]
        )
        return chat_session.send_message(prompt).text

    raise ValueError(f"Unsupported provider '{provider}'")


def get_model_response(question, options, prompt_template, model_name, clients):
    """Ask one model one question and extract its chosen option.

    Args:
        question: Question text substituted for ``{question}``.
        options: List of option strings; rendered as ``A. ...`` lines and
            substituted for ``{options}``.
        prompt_template: Template containing the two placeholders.
        model_name: Key into ``MODELS``.
        clients: Mapping of provider name to API client.

    Returns:
        A ``(answer, raw_text)`` tuple. ``answer`` is the option text chosen
        by the model, or a string starting with ``"Error:"`` on any failure.
        ``raw_text`` is the model's raw reply whenever one was received, and
        the exception message when the request itself failed. This function
        does not raise.
    """
    # Only the network call is rate-limited; parsing happens after the slot
    # has been released.
    with semaphore:
        try:
            model_config = MODELS[model_name]
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
            prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
            response_text = _query_provider(model_config, prompt, clients)
        except Exception as e:
            return f"Error: {str(e)}", str(e)

    # Worker-thread boundary: any parsing problem becomes an "Error:" result,
    # but the raw reply is kept so it can be inspected afterwards.
    try:
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not json_match:
            return "Error: Invalid response format", response_text
        json_response = json.loads(json_match.group(0))
        answer = json_response.get('answer', '').strip()
        # Drop a leading option letter such as "B. " before comparing.
        answer = re.sub(r'^[A-D]\.\s*', '', answer)
        if not any(answer.lower() == opt.lower() for opt in options):
            return f"Error: Answer '{answer}' does not match any options", response_text
        return answer, response_text
    except Exception as e:
        return f"Error: {str(e)}", response_text
def evaluate_response(model_response, correct_answer):
    """Return True when the model's answer equals the correct answer.

    Responses beginning with ``"Error:"`` are always scored as incorrect.
    Otherwise the comparison ignores case and surrounding whitespace.
    """
    def _normalise(text):
        return text.lower().strip()

    return (
        not model_response.startswith("Error:")
        and _normalise(model_response) == _normalise(correct_answer)
    )
def process_single_evaluation(question, prompt_template, model_name, clients):
    """Evaluate one model on one question and return a result record.

    Args:
        question: Dict with the keys ``question``, ``options``,
            ``correct_answer``, ``subject_name`` and ``explanation``.
        prompt_template: Template containing ``{question}`` and ``{options}``.
        model_name: Name of the model to query.
        clients: Mapping of provider name to API client.

    Returns:
        A dict describing the evaluation: the question and options, the
        model's parsed answer and raw reply, the prompt that was sent, the
        correct answer, subject, correctness flag, explanation and model name.
    """
    answer, response_text = get_model_response(
        question['question'],
        question['options'],
        prompt_template,
        model_name,
        clients
    )
    is_correct = evaluate_response(answer, question['correct_answer'])

    # Record the prompt as actually sent, applying the same substitutions
    # get_model_response performs, rather than the unfilled template.
    try:
        options_text = "\n".join(
            f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])
        )
        prompt_sent = (
            prompt_template
            .replace("{question}", question['question'])
            .replace("{options}", options_text)
        )
    except TypeError:
        # Non-string question text cannot be substituted; keep the template.
        prompt_sent = prompt_template

    return {
        'question': question['question'],
        'options': question['options'],
        'model_response': answer,
        'raw_llm_response': response_text,
        'prompt_sent': prompt_sent,
        'correct_answer': question['correct_answer'],
        'subject': question['subject_name'],
        'is_correct': is_correct,
        'explanation': question['explanation'],
        'model_name': model_name
    }
def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients):
    """Evaluate every (model, question) pair on a thread pool.

    Args:
        questions: List of question dicts.
        prompt_template: Prompt template forwarded to each evaluation.
        models_to_evaluate: Iterable of model names.
        progress_callback: Called as ``progress_callback(done, total)`` from
            the calling thread each time an evaluation finishes.
        clients: Mapping of provider name to API client.

    Returns:
        A list of result dicts ordered model by model and, within each model,
        in the order of ``questions`` -- independent of which request
        happened to finish first.
    """
    jobs = [
        (model_name, question)
        for model_name in models_to_evaluate
        for question in questions
    ]
    total_iterations = len(jobs)
    # Pre-sized so each result can be stored at its submission position.
    results = [None] * total_iterations

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
        future_to_index = {
            executor.submit(process_single_evaluation, question, prompt_template, model_name, clients): index
            for index, (model_name, question) in enumerate(jobs)
        }
        # as_completed yields in completion order, which drives the progress
        # display; the stored index restores a deterministic result order.
        for completed, future in enumerate(as_completed(future_to_index), start=1):
            results[future_to_index[future]] = future.result()
            progress_callback(completed, total_iterations)
    return results
def _render_accuracy_chart():
    """Draw the per-model accuracy bar chart from ``st.session_state.all_results``."""
    st.subheader("Evaluation Results")
    model_metrics = {}
    for model_name, results in st.session_state.all_results.items():
        df = pd.DataFrame(results)
        model_metrics[model_name] = {
            'Accuracy': df['is_correct'].mean(),
        }
    metrics_df = pd.DataFrame(model_metrics).T
    st.subheader("Model Performance Comparison")
    accuracy_chart = alt.Chart(
        metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
    ).mark_bar().encode(
        x=alt.X('index:N', title=None, axis=None),
        y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
        color=alt.Color('index:N', scale=alt.Scale(scheme='blues')),
        tooltip=['index:N', 'value:Q']
    ).properties(
        height=300,
        title={
            "text": "Model Accuracy",
            "baseline": "bottom",
            "orient": "bottom",
            "dy": 20
        }
    )
    st.altair_chart(accuracy_chart, use_container_width=True)


def _render_detailed_results():
    """Render the model/dataset pickers and one expander per evaluated question."""
    st.subheader("Detailed Results")

    # Widget callbacks: copy the widget value into the persistent selection.
    def update_model():
        st.session_state.detailed_model = st.session_state.model_select

    def update_dataset():
        st.session_state.detailed_dataset = st.session_state.dataset_select

    evaluated_models = list(st.session_state.all_results.keys())
    col1, col2 = st.columns(2)
    with col1:
        selected_model_details = st.selectbox(
            "Select model",
            options=evaluated_models,
            key="model_select",
            on_change=update_model,
            # Fall back to the first model when the remembered one was not
            # part of the latest run.
            index=evaluated_models.index(st.session_state.detailed_model)
            if st.session_state.detailed_model in st.session_state.all_results else 0
        )
    with col2:
        selected_dataset_details = st.selectbox(
            "Select dataset",
            options=[st.session_state.last_evaluated_dataset],
            key="dataset_select",
            on_change=update_dataset
        )

    if selected_model_details in st.session_state.all_results:
        results = st.session_state.all_results[selected_model_details]
        df = pd.DataFrame(results)
        accuracy = df['is_correct'].mean()
        st.metric("Accuracy", f"{accuracy:.2%}")
        for idx, result in enumerate(results):
            with st.expander(f"Question {idx + 1} - {result['subject']}"):
                st.write("**Question:**", result['question'])
                st.write("**Options:**")
                for i, opt in enumerate(result['options']):
                    st.write(f"{chr(65+i)}. {opt}")
                col1, col2 = st.columns(2)
                with col1:
                    st.write("**Prompt Used:**")
                    st.code(result['prompt_sent'])
                with col2:
                    st.write("**Raw Response:**")
                    st.code(result['raw_llm_response'])
                col1, col2 = st.columns(2)
                with col1:
                    st.write("**Correct Answer:**", result['correct_answer'])
                    st.write("**Model Answer:**", result['model_response'])
                with col2:
                    if result['is_correct']:
                        st.success("Correct!")
                    else:
                        st.error("Incorrect")
                st.write("**Explanation:**", result['explanation'])
    else:
        st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")


def _render_download_button():
    """Offer every stored result, across all models, as a single CSV download."""
    all_data = [
        {
            'dataset': st.session_state.last_evaluated_dataset,
            'model': model_name,
            'question': result['question'],
            'correct_answer': result['correct_answer'],
            'subject': result['subject'],
            'options': ' | '.join(result['options']),
            'model_response': result['model_response'],
            'is_correct': result['is_correct'],
            'explanation': result['explanation']
        }
        for model_name, results in st.session_state.all_results.items()
        for result in results
    ]
    complete_df = pd.DataFrame(all_data)
    csv = complete_df.to_csv(index=False)
    st.download_button(
        label="Download All Results as CSV",
        data=csv,
        file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
        mime="text/csv",
        key="download_all_results"
    )


def main():
    """Streamlit entry point: configure, run and display an LLM benchmark.

    The page asks for API credentials, lets the user pick a dataset, one or
    more models, a prompt template, a subject filter and a question count,
    runs the evaluation on demand, stores the results in
    ``st.session_state`` and renders a summary chart, per-question details
    and a CSV download.
    """
    st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")
    initialize_session_state()
    setup_api_clients()
    if not st.session_state.api_configured:
        st.warning("Please configure API keys in the sidebar to proceed")
        st.stop()

    # State that must survive Streamlit reruns.
    if 'all_results' not in st.session_state:
        st.session_state.all_results = {}
    if 'detailed_model' not in st.session_state:
        st.session_state.detailed_model = None
    if 'detailed_dataset' not in st.session_state:
        st.session_state.detailed_dataset = None
    if 'last_evaluated_dataset' not in st.session_state:
        st.session_state.last_evaluated_dataset = None

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        available_models = list(MODELS.keys())
        selected_model = st.multiselect(
            "Select Model(s)",
            options=available_models,
            # Slicing yields an empty default instead of raising when the
            # model registry is empty.
            default=available_models[:1],
            help="Choose one or more models to evaluate."
        )
    models_to_evaluate = selected_model

    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
Question: {question}
Options:
{options}
## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options
Example response format:
{
"answer": "exact option text here(e.g., A. xxx, B. xxx, C. xxx)",
"choice reasoning": "your detailed reasoning here",
"elimination reasoning": "why you ruled out other options"
}
Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="The below prompt is editable. Please feel free to edit it before your run."
        )
    with col2:
        st.markdown("""
### Prompt Variables
- `{question}`: The medical question
- `{options}`: The multiple choice options
""")

    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)

    # With no questions the number input below would be given a maximum
    # smaller than its minimum, which Streamlit rejects.
    if not questions:
        st.warning("No questions are available for the selected dataset.")
        st.stop()

    subjects = sorted({q['subject_name'] for q in questions})
    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
    if selected_subject != "All":
        questions = [q for q in questions if q['subject_name'] == selected_subject]
    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))

    if st.button("Start Evaluation"):
        if not models_to_evaluate:
            # Running with no model would overwrite earlier results with
            # an empty set, so refuse instead.
            st.warning("Please select at least one model to evaluate.")
        else:
            with st.spinner("Starting evaluation..."):
                selected_questions = questions[:num_questions]
                # Create a clients dictionary
                clients = {
                    "togetherai": st.session_state["togetherai_client"],
                    "openai": st.session_state["openai_client"],
                    "anthropic": st.session_state["anthropic_client"]
                }
                progress_container = st.container()
                progress_bar = progress_container.progress(0)
                status_text = progress_container.empty()

                def update_progress(current, total):
                    progress = current / total
                    progress_bar.progress(progress)
                    status_text.text(f"Progress: {current}/{total} evaluations completed")

                results = process_evaluations_concurrently(
                    selected_questions,
                    prompt_template,
                    models_to_evaluate,
                    update_progress,
                    clients
                )

                # Group the flat result list by model name.
                all_results = {}
                for result in results:
                    model = result.pop('model_name')
                    all_results.setdefault(model, []).append(result)

                st.session_state.all_results = all_results
                st.session_state.last_evaluated_dataset = selected_dataset
                if st.session_state.detailed_model is None and all_results:
                    st.session_state.detailed_model = list(all_results.keys())[0]
                if st.session_state.detailed_dataset is None:
                    st.session_state.detailed_dataset = selected_dataset
                st.success("Evaluation completed!")
                st.rerun()

    if st.session_state.all_results:
        _render_accuracy_chart()
        _render_detailed_results()
        st.markdown("---")
        _render_download_button()


if __name__ == "__main__":
    main()