boardllm / src /evaluator.py
melmoheb's picture
Upload folder using huggingface_hub
2247e66 verified
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class AnswerEvaluator:
    """Evaluates user answers against expected answers using an LLM.

    The evaluator loads a causal language model once at construction time and
    then grades free-text answers with :meth:`evaluate_answer`, which returns
    a short feedback string beginning with ``ASSESSMENT: ...``.

    Class attributes (override on a subclass or instance to tune behavior):
        MAX_PROMPT_TOKENS: Upper bound on the tokenized prompt length.
        MAX_NEW_TOKENS: Upper bound on the number of generated tokens.
        TEMPERATURE: Sampling temperature passed to ``generate``.
    """

    MAX_PROMPT_TOKENS = 1024
    MAX_NEW_TOKENS = 150
    TEMPERATURE = 0.2

    def __init__(self, model_id="meta-llama/Llama-3.2-3B-Instruct"):
        """Load the tokenizer and model identified by ``model_id``.

        Args:
            model_id: Model identifier passed to ``from_pretrained`` for both
                the tokenizer and the causal LM.

        Raises:
            Exception: Any error raised while loading the tokenizer or model
                is printed and then re-raised unchanged.
        """
        print(f"Initializing AnswerEvaluator with model: {model_id}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            if self.tokenizer.pad_token is None:
                # generate() needs a pad token; reuse EOS when none is defined.
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("Set pad_token to eos_token")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            self.model.eval()
            # Inputs are moved to this device before generation.
            self.device = self.model.device
            print(f"AnswerEvaluator model loaded successfully on device: {self.device}")
        except Exception as e:
            print(f"Error initializing AnswerEvaluator model {model_id}: {e}")
            raise

    @staticmethod
    def _build_instructions(user_answer, expected_answer, clinical_context=None):
        """Return the plain-text grading instructions (no model-specific markers)."""
        context_str = f"Clinical context: {clinical_context}\n\n" if clinical_context else ""
        return (
            "You are acting as an expert examiner for the American Board of Surgery (ABS) oral board exam. "
            "You are evaluating a general surgery resident’s answer to a clinical question. \n\n"
            "Compare the answer provided by the residents to the correct expected answer, "
            "which I will provide you with. \n\n"
            "Use the grading rubric below to assess their response:\n"
            "[RUBRIC]\n"
            "- Correct: Resident includes all major points and clinical reasoning aligns closely with the expected answer.\n"
            "- Partially Correct: Resident includes some key points but omits others, or reasoning is partially flawed.\n"
            "- Incorrect: Resident misses most key points or demonstrates incorrect reasoning.\n"
            f"{context_str}Here is the model answer that contains the key points expected from the resident:\n"
            f"{expected_answer}\n"
            "Now, here is the resident’s actual response:\n"
            f"{user_answer}\n"
            "Evaluate the resident’s response based **only** on the expected answer above. "
            "Do not rely on external knowledge or previous responses.\n"
            "Focus your evaluation on:\n"
            "1. Which key points were mentioned vs. missed\n"
            "2. The accuracy and clarity of the clinical reasoning\n"
            "3. Any major omissions or misunderstandings\n"
            "Start your output with:\n"
            "ASSESSMENT: [Correct / Partially Correct / Incorrect]\n"
            "Then write 1–2 clear, specific sentences explaining how the resident’s response "
            "compares to the expected answer.\n"
            "[EXAMPLE 1]\n"
            "Expected answer:\n"
            '"The differential diagnosis includes acute appendicitis, mesenteric adenitis, '
            'gastroenteritis, UTI, and testicular torsion."\n'
            "Resident’s response:\n"
            '"My top concern is appendicitis, but I’d also consider things like gastroenteritis '
            'or maybe even kidney stones."\n'
            "ASSESSMENT: Partially Correct\n"
            "The resident mentioned appendicitis and gastroenteritis but missed several other "
            "expected differentials like UTI, testicular torsion, and mesenteric adenitis.\n"
            "[EXAMPLE 2]\n"
            "Expected answer:\n"
            '"Initial labs should include CBC, CMP, lipase, and abdominal ultrasound to assess for gallstones."\n'
            "Resident’s response:\n"
            '"I’d start with a full workup including CBC, liver enzymes, lipase, and an abdominal ultrasound."\n'
            "ASSESSMENT: Correct\n"
            "The resident included all key labs and the correct imaging modality. "
            "Their reasoning aligns well with the expected answer."
        )

    def _render_prompt(self, instructions):
        """Wrap ``instructions`` in the loaded model's chat format.

        Returns:
            A ``(prompt_text, add_special_tokens)`` pair for the tokenizer call.
        """
        if getattr(self.tokenizer, "chat_template", None):
            # Previously the prompt was hard-wrapped in "<s>[INST] ... [/INST]</s>":
            # Llama-2/Mistral markers plus a trailing end-of-sequence token. Using
            # the tokenizer's own template keeps the format matched to whatever
            # model was loaded and ends on the assistant turn header.
            prompt = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": instructions}],
                tokenize=False,
                add_generation_prompt=True,
            )
            # The rendered template already carries its special tokens; adding
            # them again at tokenization time would duplicate BOS.
            return prompt, False
        # No chat template available: send the bare instructions and let the
        # tokenizer add its default special tokens.
        return instructions, True

    def evaluate_answer(self, user_answer: str, expected_answer: str, clinical_context: "str | None" = None) -> str:
        """
        Compare user answer to expected answer and provide feedback

        Args:
            user_answer: Examinee's response
            expected_answer: Model answer from the dataset
            clinical_context: Optional clinical context to consider

        Returns:
            Feedback string. It starts with "ASSESSMENT: Incorrect" without
            calling the model when ``user_answer`` is empty, and starts with
            "Error:" when ``expected_answer`` is empty or generation fails.
        """
        # Without a reference answer there is nothing to grade against, and the
        # model would be free to invent one.
        if not expected_answer or not str(expected_answer).strip():
            return "Error: No expected answer was provided to evaluate against."
        # A blank response trivially misses every key point; skip the model call.
        if not user_answer or not str(user_answer).strip():
            return (
                "ASSESSMENT: Incorrect\n"
                "No answer was provided, so none of the expected key points were addressed."
            )

        instructions = self._build_instructions(user_answer, expected_answer, clinical_context)

        try:
            prompt, add_special_tokens = self._render_prompt(instructions)
            # NOTE(review): truncation cuts from the end of the prompt, so an
            # over-long answer can clip the tail of the instructions; raise
            # MAX_PROMPT_TOKENS if long answers are expected.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.MAX_PROMPT_TOKENS,
                add_special_tokens=add_special_tokens,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.MAX_NEW_TOKENS,
                    do_sample=True,
                    temperature=self.TEMPERATURE,
                    pad_token_id=self.tokenizer.eos_token_id,  # Ensure pad token ID is set
                )

            # generate() returns prompt + continuation; keep only the new tokens.
            prompt_length_tokens = inputs.input_ids.shape[1]
            generated_ids = outputs[0][prompt_length_tokens:]
            feedback = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
            return feedback
        except Exception as e:
            # Best-effort: callers always receive a string, never an exception.
            print(f"Error during LLM evaluation: {e}")
            return "Error: Could not generate feedback."