Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| from together import Together | |
| from dotenv import load_dotenv | |
| from datasets import load_dataset | |
| import json | |
| import re | |
| import os | |
| from config import DATASETS, MODELS | |
| load_dotenv() | |
| client = Together(api_key=os.getenv('TOGETHERAI_API_KEY')) | |
| def load_dataset_by_name(dataset_name, split="train"): | |
| dataset_config = DATASETS[dataset_name] | |
| dataset = load_dataset(dataset_config["loader"]) | |
| df = pd.DataFrame(dataset[split]) | |
| df = df[df['choice_type'] == 'single'] | |
| questions = [] | |
| for _, row in df.iterrows(): | |
| options = [row['opa'], row['opb'], row['opc'], row['opd']] | |
| correct_answer = options[row['cop']] | |
| question_dict = { | |
| 'question': row['question'], | |
| 'options': options, | |
| 'correct_answer': correct_answer, | |
| 'subject_name': row['subject_name'], | |
| 'topic_name': row['topic_name'], | |
| 'explanation': row['exp'] | |
| } | |
| questions.append(question_dict) | |
| st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}") | |
| return questions | |
| def get_model_response(question, options, prompt_template, model_name): | |
| try: | |
| model_config = MODELS[model_name] | |
| options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)]) | |
| prompt = prompt_template.replace("{question}", question).replace("{options}", options_text) | |
| response = client.chat.completions.create( | |
| model=model_config["model_id"], | |
| messages=[{"role": "user", "content": prompt}] | |
| ) | |
| response_text = response.choices[0].message.content.strip() | |
| json_match = re.search(r'\{.*\}', response_text, re.DOTALL) | |
| json_response = json.loads(json_match.group(0)) | |
| answer = json_response['answer'].strip() | |
| answer = re.sub(r'^[A-D]\.\s*', '', answer) | |
| if not any(answer.lower() == opt.lower() for opt in options): | |
| return f"Error: Answer '{answer}' does not match any options" | |
| return answer | |
| except Exception as e: | |
| return f"Error: {str(e)}" | |
| def evaluate_response(model_response, correct_answer): | |
| if model_response.startswith("Error:"): | |
| return False | |
| return model_response.lower().strip() == correct_answer.lower().strip() | |
| def main(): | |
| st.set_page_config(page_title="Medical LLM Evaluation", layout="wide") | |
| st.title("Medical LLM Evaluation") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| selected_dataset = st.selectbox( | |
| "Select Dataset", | |
| options=list(DATASETS.keys()), | |
| help="Choose the dataset to evaluate on" | |
| ) | |
| with col2: | |
| selected_model = st.selectbox( | |
| "Select Model", | |
| options=list(MODELS.keys()), | |
| help="Choose the model to evaluate" | |
| ) | |
| default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question. | |
| Question: {question} | |
| Options: | |
| {options} | |
| ## Output Format: | |
| Please provide you answer in JSON format that contains an "answer" field. | |
| You may include any additional fields in your JSON response that you find relevant, such as: | |
| - "answer": the option you selected | |
| - "choice reasoning": your detailed reasoning | |
| - "elimination reasoning": why you ruled out other options | |
| Example response format: | |
| { | |
| "answer": "exact option text here", | |
| "choice reasoning": "your detailed reasoning here", | |
| "elimination reasoning": "why you ruled out other options" | |
| } | |
| Important: | |
| - Only the "answer" field will be used for evaluation | |
| - Ensure your response is in valid JSON format''' | |
| col1, col2 = st.columns([2, 1]) | |
| with col1: | |
| prompt_template = st.text_area( | |
| "Customize Prompt Template", | |
| default_prompt, | |
| height=400, | |
| help="The below prompt is editable. Please feel free to edit it before your run." | |
| ) | |
| with col2: | |
| st.markdown(""" | |
| ### Prompt Variables | |
| - `{question}`: The medical question | |
| - `{options}`: The multiple choice options | |
| """) | |
| with st.spinner("Loading dataset..."): | |
| questions = load_dataset_by_name(selected_dataset) | |
| if not questions: | |
| st.error("No questions were loaded successfully.") | |
| return | |
| subjects = list(set(q['subject_name'] for q in questions)) | |
| selected_subject = st.selectbox("Filter by subject", ["All"] + subjects) | |
| if selected_subject != "All": | |
| questions = [q for q in questions if q['subject_name'] == selected_subject] | |
| num_questions = st.number_input("Number of questions to evaluate", 1, len(questions)) | |
| if st.button("Start Evaluation"): | |
| if not os.getenv('TOGETHERAI_API_KEY'): | |
| st.error("Please set the TOGETHERAI_API_KEY in your .env file") | |
| return | |
| progress_bar = st.progress(0) | |
| status_text = st.empty() | |
| results_container = st.container() | |
| results = [] | |
| for i in range(num_questions): | |
| question = questions[i] | |
| progress = (i + 1) / num_questions | |
| progress_bar.progress(progress) | |
| status_text.text(f"Evaluating question {i + 1}/{num_questions}") | |
| model_response = get_model_response( | |
| question['question'], | |
| question['options'], | |
| prompt_template, | |
| selected_model | |
| ) | |
| options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])]) | |
| formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text) | |
| raw_response = client.chat.completions.create( | |
| model=MODELS[selected_model]["model_id"], | |
| messages=[{"role": "user", "content": formatted_prompt}] | |
| ).choices[0].message.content.strip() | |
| is_correct = evaluate_response(model_response, question['correct_answer']) | |
| results.append({ | |
| 'question': question['question'], | |
| 'options': question['options'], | |
| 'model_response': model_response, | |
| 'raw_llm_response': raw_response, | |
| 'prompt_sent': formatted_prompt, | |
| 'correct_answer': question['correct_answer'], | |
| 'subject': question['subject_name'], | |
| 'is_correct': is_correct, | |
| 'explanation': question['explanation'] | |
| }) | |
| with results_container: | |
| st.subheader("Evaluation Results") | |
| df = pd.DataFrame(results) | |
| accuracy = df['is_correct'].mean() | |
| st.metric("Accuracy", f"{accuracy:.2%}") | |
| for idx, result in enumerate(results): | |
| st.markdown("---") | |
| st.subheader(f"Question {idx + 1} - {result['subject']}") | |
| st.write("Question:", result['question']) | |
| st.write("Options:") | |
| for i, opt in enumerate(result['options']): | |
| st.write(f"{chr(65+i)}. {opt}") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| with st.expander("Show Prompt"): | |
| st.code(result['prompt_sent']) | |
| with col2: | |
| with st.expander("Show Raw Response"): | |
| st.code(result['raw_llm_response']) | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.write("Correct Answer:", result['correct_answer']) | |
| st.write("Model Answer:", result['model_response']) | |
| with col2: | |
| if result['is_correct']: | |
| st.success("Correct!") | |
| else: | |
| st.error("Incorrect") | |
| with st.expander("Show Explanation"): | |
| st.write(result['explanation']) | |
| if __name__ == "__main__": | |
| main() |