ehagey committed
Commit 01cea67 · verified · 1 Parent(s): 697c4ea

Update app.py

Files changed (1):
  1. app.py +163 -67
app.py CHANGED
@@ -7,6 +7,8 @@ import json
 import re
 import os
 from config import DATASETS, MODELS
+import matplotlib.pyplot as plt
+import altair as alt
 
 load_dotenv()
 client = Together(api_key=os.getenv('TOGETHERAI_API_KEY'))
@@ -46,7 +48,6 @@ def get_model_response(question, options, prompt_template, model_name):
             model=model_config["model_id"],
             messages=[{"role": "user", "content": prompt}]
         )
-
         response_text = response.choices[0].message.content.strip()
         json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
         json_response = json.loads(json_match.group(0))
@@ -57,7 +58,6 @@ def get_model_response(question, options, prompt_template, model_name):
             return f"Error: Answer '{answer}' does not match any options"
 
         return answer
-
     except Exception as e:
         return f"Error: {str(e)}"
 
@@ -70,6 +70,14 @@ def main():
     st.set_page_config(page_title="LLM Benchmarking in Healthcare", layout="wide")
     st.title("LLM Benchmarking in Healthcare")
 
+    if 'all_results' not in st.session_state:
+        st.session_state.all_results = {}
+    if 'detailed_model' not in st.session_state:
+        st.session_state.detailed_model = None
+    if 'detailed_dataset' not in st.session_state:
+        st.session_state.detailed_dataset = None
+    if 'last_evaluated_dataset' not in st.session_state:
+        st.session_state.last_evaluated_dataset = None
     col1, col2 = st.columns(2)
     with col1:
         selected_dataset = st.selectbox(
@@ -78,12 +86,15 @@
             help="Choose the dataset to evaluate on"
         )
     with col2:
-        selected_model = st.selectbox(
-            "Select Model",
+        selected_model = st.multiselect(
+            "Select Model(s)",
             options=list(MODELS.keys()),
-            help="Choose the model to evaluate"
+            default=[list(MODELS.keys())[0]],
+            help="Choose one or more models to evaluate."
         )
 
+    models_to_evaluate = selected_model
+
     default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
 
 Question: {question}
@@ -144,80 +155,165 @@ Important:
         st.error("Please set the TOGETHERAI_API_KEY in your .env file")
         return
 
-    progress_bar = st.progress(0)
-    status_text = st.empty()
+    progress_container = st.container()
+    with progress_container:
+        progress_bar = st.progress(0)
+        status_text = st.empty()
+        substatus_text = st.empty()
+
     results_container = st.container()
+    all_results = {}
 
-    results = []
-    for i in range(num_questions):
-        question = questions[i]
-        progress = (i + 1) / num_questions
-        progress_bar.progress(progress)
-        status_text.text(f"Evaluating question {i + 1}/{num_questions}")
-
-        model_response = get_model_response(
-            question['question'],
-            question['options'],
-            prompt_template,
-            selected_model
-        )
-
-        options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])])
-        formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text)
-        raw_response = client.chat.completions.create(
-            model=MODELS[selected_model]["model_id"],
-            messages=[{"role": "user", "content": formatted_prompt}]
-        ).choices[0].message.content.strip()
-
-        is_correct = evaluate_response(model_response, question['correct_answer'])
-
-        results.append({
-            'question': question['question'],
-            'options': question['options'],
-            'model_response': model_response,
-            'raw_llm_response': raw_response,
-            'prompt_sent': formatted_prompt,
-            'correct_answer': question['correct_answer'],
-            'subject': question['subject_name'],
-            'is_correct': is_correct,
-            'explanation': question['explanation']
-        })
-
-    with results_container:
-        st.subheader("Evaluation Results")
+    total_iterations = len(models_to_evaluate) * num_questions
+    current_iteration = 0
+
+    for model_name in models_to_evaluate:
+        substatus_text.markdown(f"<small>Evaluating model: {model_name} on {selected_dataset}</small>", unsafe_allow_html=True)
+
+        results = []
+        for i in range(num_questions):
+            question = questions[i]
+            current_iteration += 1
+            progress = current_iteration / total_iterations
+            progress_bar.progress(progress)
+            status_text.text(f"Progress: {current_iteration}/{total_iterations} evaluations")
+
+            model_response = get_model_response(
+                question['question'],
+                question['options'],
+                prompt_template,
+                model_name
+            )
+
+            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])])
+            formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text)
+            raw_response = client.chat.completions.create(
+                model=MODELS[model_name]["model_id"],
+                messages=[{"role": "user", "content": formatted_prompt}],
+                temperature=0.7
+            ).choices[0].message.content.strip()
+
+            is_correct = evaluate_response(model_response, question['correct_answer'])
+
+            results.append({
+                'question': question['question'],
+                'options': question['options'],
+                'model_response': model_response,
+                'raw_llm_response': raw_response,
+                'prompt_sent': formatted_prompt,
+                'correct_answer': question['correct_answer'],
+                'subject': question['subject_name'],
+                'is_correct': is_correct,
+                'explanation': question['explanation']
+            })
+
+        all_results[model_name] = results
+    st.session_state.all_results = all_results
+    st.session_state.last_evaluated_dataset = selected_dataset
+
+    if st.session_state.detailed_model is None and all_results:
+        st.session_state.detailed_model = list(all_results.keys())[0]
+    if st.session_state.detailed_dataset is None:
+        st.session_state.detailed_dataset = selected_dataset
+
+    st.rerun()
+
+    if st.session_state.all_results:
+        st.subheader("Evaluation Results")
+
+        model_metrics = {}
+        for model_name, results in st.session_state.all_results.items():
+            df = pd.DataFrame(results)
+            metrics = {
+                'Accuracy': df['is_correct'].mean(),
+            }
+            model_metrics[model_name] = metrics
+
+        metrics_df = pd.DataFrame(model_metrics).T
+
+        st.subheader("Model Performance Comparison")
+
+        accuracy_chart = alt.Chart(
+            metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
+        ).mark_bar().encode(
+            x=alt.X('index:N', title=None, axis=None),
+            y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
+            color='index:N'
+        ).properties(
+            height=300,
+            title={
+                "text": "Model Accuracy",
+                "baseline": "bottom",
+                "orient": "bottom",
+                "dy": 20
+            }
+        )
+        st.altair_chart(accuracy_chart, use_container_width=True)
+
+    if st.session_state.all_results:
+        st.subheader("Detailed Results")
+
+        def update_model():
+            st.session_state.detailed_model = st.session_state.model_select
+
+        def update_dataset():
+            st.session_state.detailed_dataset = st.session_state.dataset_select
+
+        col1, col2 = st.columns(2)
+        with col1:
+            selected_model_details = st.selectbox(
+                "Select model",
+                options=list(st.session_state.all_results.keys()),
+                key="model_select",
+                on_change=update_model,
+                index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model)
+                if st.session_state.detailed_model in st.session_state.all_results else 0
+            )
+
+        with col2:
+            selected_dataset_details = st.selectbox(
+                "Select dataset",
+                options=[st.session_state.last_evaluated_dataset],
+                key="dataset_select",
+                on_change=update_dataset
+            )
+
+        if selected_model_details in st.session_state.all_results:
+            results = st.session_state.all_results[selected_model_details]
             df = pd.DataFrame(results)
             accuracy = df['is_correct'].mean()
+
             st.metric("Accuracy", f"{accuracy:.2%}")
 
             for idx, result in enumerate(results):
-                st.markdown("---")
-                st.subheader(f"Question {idx + 1} - {result['subject']}")
-
-                st.write("Question:", result['question'])
-                st.write("Options:")
-                for i, opt in enumerate(result['options']):
-                    st.write(f"{chr(65+i)}. {opt}")
-
-                col1, col2 = st.columns(2)
-                with col1:
-                    with st.expander("Show Prompt"):
+                with st.expander(f"Question {idx + 1} - {result['subject']}"):
+                    st.write("Question:", result['question'])
+                    st.write("Options:")
+                    for i, opt in enumerate(result['options']):
+                        st.write(f"{chr(65+i)}. {opt}")
+
+                    col1, col2 = st.columns(2)
+                    with col1:
+                        st.write("Prompt Used:")
                         st.code(result['prompt_sent'])
-                with col2:
-                    with st.expander("Show Raw Response"):
+                    with col2:
+                        st.write("Raw Response:")
                         st.code(result['raw_llm_response'])
-
-                col1, col2 = st.columns(2)
-                with col1:
-                    st.write("Correct Answer:", result['correct_answer'])
-                    st.write("Model Answer:", result['model_response'])
-                with col2:
-                    if result['is_correct']:
-                        st.success("Correct!")
-                    else:
-                        st.error("Incorrect")
-
-                with st.expander("Show Explanation"):
-                    st.write(result['explanation'])
+
+                    col1, col2 = st.columns(2)
+                    with col1:
+                        st.write("Correct Answer:", result['correct_answer'])
+                        st.write("Model Answer:", result['model_response'])
+                    with col2:
+                        if result['is_correct']:
+                            st.success("Correct!")
+                        else:
+                            st.error("Incorrect")
+
+                    st.write("Explanation:", result['explanation'])
+        else:
+            st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")
 
 if __name__ == "__main__":
     main()
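Note on the state handling introduced above: results now live in st.session_state instead of loop-local variables, and st.rerun() restarts the script so the charts and detail views render from persisted state on the next pass. A minimal sketch of that pattern, separate from this app (the "demo-model" key is made up; assumes a Streamlit version that ships st.rerun):

import streamlit as st

# Initialize state keys once; they survive every rerun of the script.
if 'all_results' not in st.session_state:
    st.session_state.all_results = {}

if st.button("Run evaluation"):
    # ... expensive evaluation work would happen here ...
    st.session_state.all_results = {"demo-model": [{"is_correct": True}]}
    # Restart the script from the top; the block below then renders
    # from the state that was just persisted.
    st.rerun()

if st.session_state.all_results:
    st.write(st.session_state.all_results)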
 
+ st.session_state.detailed_model = list(all_results.keys())[0]
217
+ if st.session_state.detailed_dataset is None:
218
+ st.session_state.detailed_dataset = selected_dataset
219
+
220
+ st.rerun()
221
+
222
+ if st.session_state.all_results:
223
+ st.subheader("Evaluation Results")
224
+
225
+ model_metrics = {}
226
+ for model_name, results in st.session_state.all_results.items():
227
+ df = pd.DataFrame(results)
228
+ metrics = {
229
+ 'Accuracy': df['is_correct'].mean(),
230
+ }
231
+ model_metrics[model_name] = metrics
232
+
233
+ metrics_df = pd.DataFrame(model_metrics).T
234
+
235
+ st.subheader("Model Performance Comparison")
236
+
237
+ accuracy_chart = alt.Chart(
238
+ metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
239
+ ).mark_bar().encode(
240
+ x=alt.X('index:N', title=None, axis=None),
241
+ y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
242
+ color='index:N'
243
+ ).properties(
244
+ height=300,
245
+ title={
246
+ "text": "Model Accuracy",
247
+ "baseline": "bottom",
248
+ "orient": "bottom",
249
+ "dy": 20
250
+ }
251
+ )
252
+ st.altair_chart(accuracy_chart, use_container_width=True)
253
+
254
+ if st.session_state.all_results:
255
+ st.subheader("Detailed Results")
256
+
257
+ def update_model():
258
+ st.session_state.detailed_model = st.session_state.model_select
259
 
260
+ def update_dataset():
261
+ st.session_state.detailed_dataset = st.session_state.dataset_select
262
+
263
+ col1, col2 = st.columns(2)
264
+ with col1:
265
+ selected_model_details = st.selectbox(
266
+ "Select model",
267
+ options=list(st.session_state.all_results.keys()),
268
+ key="model_select",
269
+ on_change=update_model,
270
+ index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model)
271
+ if st.session_state.detailed_model in st.session_state.all_results else 0
272
+ )
273
+
274
+ with col2:
275
+ selected_dataset_details = st.selectbox(
276
+ "Select dataset",
277
+ options=[st.session_state.last_evaluated_dataset],
278
+ key="dataset_select",
279
+ on_change=update_dataset
280
+ )
281
+
282
+ if selected_model_details in st.session_state.all_results:
283
+ results = st.session_state.all_results[selected_model_details]
284
  df = pd.DataFrame(results)
285
  accuracy = df['is_correct'].mean()
286
+
287
  st.metric("Accuracy", f"{accuracy:.2%}")
288
 
289
  for idx, result in enumerate(results):
290
+ with st.expander(f"Question {idx + 1} - {result['subject']}"):
291
+ st.write("Question:", result['question'])
292
+ st.write("Options:")
293
+ for i, opt in enumerate(result['options']):
294
+ st.write(f"{chr(65+i)}. {opt}")
295
+
296
+ col1, col2 = st.columns(2)
297
+ with col1:
298
+ st.write("Prompt Used:")
 
 
299
  st.code(result['prompt_sent'])
300
+ with col2:
301
+ st.write("Raw Response:")
302
  st.code(result['raw_llm_response'])
303
+
304
+ col1, col2 = st.columns(2)
305
+ with col1:
306
+ st.write("Correct Answer:", result['correct_answer'])
307
+ st.write("Model Answer:", result['model_response'])
308
+ with col2:
309
+ if result['is_correct']:
310
+ st.success("Correct!")
311
+ else:
312
+ st.error("Incorrect")
313
+
314
+ st.write("Explanation:", result['explanation'])
315
+ else:
316
+ st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")
317
 
318
  if __name__ == "__main__":
319
  main()
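One spot the commit leaves untouched: get_model_response extracts the answer by regex-matching the first {...} span and passing it straight to json.loads, so a reply with no JSON raises AttributeError on json_match.group(0) and surfaces only as the generic except-branch error string. A hedged sketch of a guarded variant (parse_answer is a hypothetical helper, not a function in app.py):

import json
import re

def parse_answer(response_text: str):
    # Grab the first {...} span, tolerating newlines inside it.
    match = re.search(r'\{.*\}', response_text, re.DOTALL)
    if match is None:
        return None, "Error: no JSON object found in model response"
    try:
        payload = json.loads(match.group(0))
    except json.JSONDecodeError as e:
        return None, f"Error: malformed JSON in model response ({e})"
    return payload.get("answer"), None

answer, err = parse_answer('The answer is {"answer": "B"}')
assert answer == "B" and err is None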