# medical-benchmark-scripts / zero_shot_benchmark.py
# Uploaded by airevartis via huggingface_hub (commit aa965c5, verified)
#!/usr/bin/env python3
"""
Zero-shot benchmark evaluation on Hugging Face infrastructure
"""
import torch
import json
import yaml
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
BitsAndBytesConfig
)
from datasets import load_dataset
import numpy as np
from typing import Dict, List, Tuple
import logging
import re
from pathlib import Path
import os
# Module-level logging: INFO level, logger named after this module so
# records from this script are distinguishable in aggregated job logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class HFZeroShotBenchmark:
    """Zero-shot multiple-choice benchmark runner for medical QA models.

    For each configured model: load it from the HF Hub (4-bit quantized when
    CUDA is available), prompt it with a question plus lettered options,
    parse the predicted letter (A-E) from the generation, and accumulate
    overall and per-option accuracy metrics.
    """

    def __init__(self):
        # Prefer the GPU; quantization below is only applied on CUDA.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        # Registry of benchmark models: key -> HF Hub repo id + architecture type.
        self.models = {
            "biomistral_7b": {
                "name": "BioMistral/BioMistral-7B",
                "type": "causal_lm"
            },
            "qwen3_7b": {
                "name": "Qwen/Qwen2.5-7B-Instruct",
                "type": "causal_lm"
            },
            "meditron_7b": {
                "name": "epfl-llm/meditron-7b",
                "type": "causal_lm"
            },
            "internist_7b": {
                "name": "internistai/internist-7b",
                "type": "causal_lm"
            }
        }
        # 4-bit NF4 double quantization so 7B models fit in a single-GPU job.
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )

    def load_model(self, model_name: str, model_config: Dict) -> Tuple:
        """Load a model and its tokenizer from the HF Hub.

        Returns (model, tokenizer), or (None, None) on failure so the caller
        can skip this model instead of aborting the whole benchmark.
        """
        logger.info(f"Loading model: {model_name}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_config['name'],
                trust_remote_code=True
            )
            # Some causal LMs ship without a pad token; reuse EOS for padding.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_config['name'],
                # Quantize + auto-shard only on CUDA; plain fp32 on CPU.
                quantization_config=self.quantization_config if self.device == "cuda" else None,
                device_map="auto" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                trust_remote_code=True
            )
            logger.info(f"Successfully loaded {model_name}")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load {model_name}: {e}")
            return None, None

    def create_prompt(self, question: str, options: List[str], model_name: str) -> str:
        """Build a model-family-specific multiple-choice prompt.

        Options are lettered A, B, C, ... in order. Qwen gets ChatML markers,
        Mistral-family models get [INST] tags, everything else a plain
        Question/Answer format.
        """
        options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
        if "qwen" in model_name.lower():
            return f"""<|im_start|>user
{question}
{options_text}
Please select the correct answer (A, B, C, D, or E).<|im_end|>
<|im_start|>assistant
The correct answer is"""
        elif "mistral" in model_name.lower():
            # "biomistral" already contains "mistral", so one check covers both.
            return f"""<s>[INST] {question}
{options_text}
Please select the correct answer (A, B, C, D, or E). [/INST] The correct answer is"""
        else:
            # Generic format
            return f"""Question: {question}
{options_text}
Answer:"""

    def extract_answer(self, text: str) -> str:
        """Extract the predicted option letter (A-E) from model output.

        Tries explicit answer phrasings first, then any standalone letter;
        falls back to "A" when nothing matches so downstream scoring never
        sees a missing prediction.
        """
        patterns = [
            r'[Tt]he correct answer is ([A-E])',
            r'[Aa]nswer: ([A-E])',
            r'([A-E])\.',
            r'^([A-E])\s*$'
        ]
        for pattern in patterns:
            match = re.search(pattern, text)
            if match:
                return match.group(1)
        # If no clear pattern, return the first A-E found anywhere.
        match = re.search(r'([A-E])', text)
        if match:
            return match.group(1)
        return "A"  # Default fallback (counts as a guess, never a crash)

    def evaluate_model(self, model_name: str, model_config: Dict, test_dataset) -> Dict:
        """Evaluate a single model on the test dataset.

        Returns a metrics dict (accuracy, per-option accuracies), or an
        {'error': ...} dict if the model could not be loaded.
        """
        logger.info(f"Evaluating {model_name}")
        model, tokenizer = self.load_model(model_name, model_config)
        if model is None or tokenizer is None:
            return {"error": f"Failed to load {model_name}"}
        # Greedy decoding. NOTE: the original also passed temperature=0.1,
        # which transformers warns about and ignores when do_sample=False;
        # it is dropped here for a clean, equivalent greedy configuration.
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
        results = []
        correct = 0
        total = len(test_dataset)
        logger.info(f"Running evaluation on {total} examples")
        for i, example in enumerate(test_dataset):
            try:
                prompt = self.create_prompt(
                    example['question'],
                    example['options'],
                    model_name
                )
                response = generator(prompt, return_full_text=False)
                generated_text = response[0]['generated_text']
                predicted_answer = self.extract_answer(generated_text)
                true_answer = example['answer']
                is_correct = predicted_answer == true_answer
                if is_correct:
                    correct += 1
                results.append({
                    'question_id': i,
                    'question': example['question'],
                    'options': example['options'],
                    'true_answer': true_answer,
                    'predicted_answer': predicted_answer,
                    'generated_text': generated_text,
                    'is_correct': is_correct
                })
            except Exception as e:
                # One bad example must not abort the run; record and continue.
                logger.error(f"Error processing example {i}: {e}")
                results.append({
                    'question_id': i,
                    'error': str(e),
                    'is_correct': False
                })
        # Overall accuracy (guard against an empty dataset).
        accuracy = correct / total if total > 0 else 0
        # Per-option accuracy: hit rate restricted to examples whose gold answer is X.
        option_accuracies = {}
        for option in ['A', 'B', 'C', 'D', 'E']:
            option_correct = sum(1 for r in results if r.get('true_answer') == option and r.get('is_correct', False))
            option_total = sum(1 for r in results if r.get('true_answer') == option)
            option_accuracies[option] = option_correct / option_total if option_total > 0 else 0
        metrics = {
            'model_name': model_name,
            'total_examples': total,
            'correct_predictions': correct,
            'accuracy': accuracy,
            'option_accuracies': option_accuracies
        }
        logger.info(f"{model_name} - Accuracy: {accuracy:.4f}")
        # Free GPU memory before the next model is loaded.
        del model, tokenizer, generator
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return metrics

    def run_benchmark(self, test_dataset) -> Dict:
        """Run the benchmark sequentially on every configured model."""
        results = {}
        for model_name, model_config in self.models.items():
            logger.info(f"Starting evaluation for {model_name}")
            results[model_name] = self.evaluate_model(model_name, model_config, test_dataset)
        return results

    def save_results(self, results: Dict, output_path: str = "/tmp/zero_shot_results.json"):
        """Save evaluation metrics as JSON and return the output path.

        Only the summary metrics are serialized (per-example records are
        dropped) to keep the JSON small and serializable.
        """
        serializable_results = {}
        for model_name, result in results.items():
            if 'error' in result:
                serializable_results[model_name] = result
            else:
                serializable_results[model_name] = {
                    'model_name': result['model_name'],
                    'total_examples': result['total_examples'],
                    'correct_predictions': result['correct_predictions'],
                    'accuracy': result['accuracy'],
                    'option_accuracies': result['option_accuracies']
                }
        with open(output_path, 'w') as f:
            json.dump(serializable_results, f, indent=2)
        logger.info(f"Results saved to {output_path}")
        return output_path
def main():
    """Entry point for the HF job: load MedQA, run all models, save/upload results."""
    logger.info("Starting zero-shot benchmark on Hugging Face infrastructure")
    # Load MedQA dataset, trying a few known Hub identifiers in order.
    logger.info("Loading MedQA dataset...")
    try:
        dataset_names = ["bigbio/med_qa", "medqa", "medqa_usmle"]
        dataset = None
        for name in dataset_names:
            try:
                dataset = load_dataset(name)
                logger.info(f"Loaded dataset: {name}")
                break
            except Exception as e:
                # Narrowed from a bare except: never swallow SystemExit/KeyboardInterrupt.
                logger.debug(f"Could not load dataset '{name}': {e}")
                continue
        if dataset is None:
            logger.error("Could not load MedQA dataset")
            return

        def process_example(example):
            """Normalize one example to {'question', 'options', 'answer'} across formats."""
            # The question text lives under different keys depending on the dataset.
            if 'question' in example:
                question = example['question']
            elif 'text' in example:
                question = example['text']
            else:
                question = example['input']
            # Multiple-choice options likewise vary by format.
            if 'options' in example:
                options = example['options']
            elif 'choices' in example:
                options = example['choices']
            else:
                # Fall back to numbered option_i / choice_i fields.
                options = []
                for i in range(5):  # MedQA typically has 5 options
                    key = f'option_{i}' if f'option_{i}' in example else f'choice_{i}'
                    if key in example:
                        options.append(example[key])
            # Gold answer key also varies.
            if 'answer' in example:
                answer = example['answer']
            elif 'label' in example:
                answer = example['label']
            else:
                answer = example['output']
            return {
                'question': question,
                'options': options,
                'answer': answer
            }

        # Normalize the test split to the benchmark's expected schema.
        test_dataset = dataset['test'].map(process_example)
        logger.info(f"Processed {len(test_dataset)} test examples")
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        return
    # Initialize and run the benchmark across all configured models.
    benchmark = HFZeroShotBenchmark()
    logger.info("Starting zero-shot benchmark evaluation...")
    results = benchmark.run_benchmark(test_dataset)
    # Save results
    output_path = benchmark.save_results(results)
    # Print summary
    print("\n" + "="*50)
    print("ZERO-SHOT BENCHMARK RESULTS")
    print("="*50)
    for model_name, result in results.items():
        if 'error' not in result:
            print(f"{model_name}: {result['accuracy']:.4f} accuracy")
    # Upload results to HF Hub (best-effort; requires auth to be configured).
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        repo_name = "medical-benchmark-results"
        try:
            api.create_repo(repo_name, exist_ok=True)
        except Exception as e:
            # Repo may already exist or token may lack create rights;
            # the upload below may still succeed, so only log it.
            logger.debug(f"create_repo skipped: {e}")
        api.upload_file(
            path_or_fileobj=output_path,
            path_in_repo="zero_shot_benchmark.json",
            repo_id=repo_name,
            repo_type="dataset"
        )
        logger.info(f"Results uploaded to {repo_name}/zero_shot_benchmark.json")
    except Exception as e:
        logger.warning(f"Could not upload results to HF Hub: {e}")
    logger.info("Zero-shot benchmark completed!")
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()