File size: 7,607 Bytes
c7ebaa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""
Evaluation module for BioRLHF.

This module provides functionality for evaluating fine-tuned models on
biological reasoning tasks.
"""

import json
from pathlib import Path
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from biorlhf.utils.model_utils import get_quantization_config


@dataclass
class EvaluationResult:
    """Results from model evaluation.

    Accuracy fields are fractions in [0, 1] computed as correct / total for
    the relevant subset of questions; ``evaluate_model`` reports 0.0 for a
    subset that contains no questions.

    Attributes:
        overall_accuracy: Fraction of all evaluated questions answered correctly.
        factual_accuracy: Accuracy on questions in the "factual" category.
        reasoning_accuracy: Accuracy on questions in the "reasoning" category.
        calibration_accuracy: Accuracy on questions in the "calibration" category.
        total_questions: Number of questions scored.
        correct_answers: Number of questions judged correct.
        detailed_results: One dict per question; as built by ``evaluate_model``
            each has the keys "question", "expected", "response", "correct"
            and "category".
    """

    overall_accuracy: float  # correct / total over every scored question
    factual_accuracy: float  # accuracy restricted to category "factual"
    reasoning_accuracy: float  # accuracy restricted to category "reasoning"
    calibration_accuracy: float  # accuracy restricted to category "calibration"
    total_questions: int  # number of questions scored
    correct_answers: int  # number of questions judged correct
    detailed_results: List[Dict]  # per-question records (see class docstring)


def evaluate_model(
    model_path: str,
    test_questions_path: str,
    base_model: str = "mistralai/Mistral-7B-v0.3",
    use_4bit: bool = True,
    max_new_tokens: int = 512,
    temperature: float = 0.1,
) -> EvaluationResult:
    """
    Evaluate a fine-tuned model on a test set.

    Loads ``base_model`` (optionally 4-bit quantized), applies the PEFT
    adapter stored at ``model_path``, generates one response per test
    question and scores it with ``_check_answer``.

    Each test question is a dict with a required ``"question"`` key and
    optional ``"expected_answer"``, ``"keywords"`` and ``"category"`` keys
    (``"category"`` defaults to ``"factual"``).

    Args:
        model_path: Path to the fine-tuned (PEFT adapter) model; the tokenizer
            is loaded from here as well.
        test_questions_path: Path to a JSON file containing a list of
            test-question dicts.
        base_model: Base model name the adapter is applied to.
        use_4bit: Use 4-bit quantization.
        max_new_tokens: Maximum tokens to generate per question.
        temperature: Sampling temperature; values <= 0 select greedy decoding.

    Returns:
        EvaluationResult with accuracy metrics. Questions whose category is
        not factual/reasoning/calibration count toward the overall accuracy
        only (previously they raised ``KeyError``).
    """
    print(f"Loading model from {model_path}...")

    # Load quantization config
    bnb_config = get_quantization_config() if use_4bit else None

    # Load base model, then attach the fine-tuned adapter
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model = PeftModel.from_pretrained(model, model_path)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Load test questions (explicit UTF-8: biological text often contains
    # non-ASCII symbols such as Greek letters)
    with open(test_questions_path, "r", encoding="utf-8") as f:
        test_questions = json.load(f)

    print(f"Evaluating on {len(test_questions)} questions...")

    results = []
    # The three reported categories are pre-seeded so their lookups below
    # always succeed; any other category is added on first sight.
    category_correct: Dict[str, int] = {"factual": 0, "reasoning": 0, "calibration": 0}
    category_total: Dict[str, int] = {"factual": 0, "reasoning": 0, "calibration": 0}

    for q in test_questions:
        prompt = f"### Instruction:\n{q['question']}\n\n### Response:\n"
        response = _generate_response(model, tokenizer, prompt, max_new_tokens, temperature)

        expected = q.get("expected_answer", "")
        is_correct = _check_answer(response, expected, q.get("keywords", []))

        category = q.get("category", "factual")
        category_total[category] = category_total.get(category, 0) + 1
        if is_correct:
            category_correct[category] = category_correct.get(category, 0) + 1

        results.append({
            "question": q["question"],
            "expected": expected,
            "response": response,
            "correct": is_correct,
            "category": category,
        })

    # Compute metrics
    total_correct = sum(category_correct.values())
    total_questions = sum(category_total.values())

    return EvaluationResult(
        overall_accuracy=_safe_ratio(total_correct, total_questions),
        factual_accuracy=_safe_ratio(category_correct["factual"], category_total["factual"]),
        reasoning_accuracy=_safe_ratio(category_correct["reasoning"], category_total["reasoning"]),
        calibration_accuracy=_safe_ratio(category_correct["calibration"], category_total["calibration"]),
        total_questions=total_questions,
        correct_answers=total_correct,
        detailed_results=results,
    )


def _generate_response(model, tokenizer, prompt: str, max_new_tokens: int, temperature: float) -> str:
    """Generate a completion for ``prompt`` and return only the new text, stripped."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature > 0:
        # Only meaningful when sampling; passing it with greedy decoding makes
        # transformers emit a generation-config warning.
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Drop the prompt by token count, not character count: decoding the full
    # sequence with skip_special_tokens is not guaranteed to reproduce the
    # prompt string exactly, so character slicing can cut or leak text.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()


def _safe_ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def _check_answer(response: str, expected: str, keywords: List[str]) -> bool:
    """
    Check if a response is correct based on expected answer and keywords.

    Args:
        response: Model's response.
        expected: Expected answer (can be partial).
        keywords: Keywords that should appear in correct response.

    Returns:
        True if answer is considered correct.
    """
    response_lower = response.lower()

    # Check for keywords
    if keywords:
        return all(kw.lower() in response_lower for kw in keywords)

    # Check for expected answer substring
    if expected:
        return expected.lower() in response_lower

    return False


def compute_metrics(results: List[Dict]) -> Dict[str, float]:
    """
    Compute evaluation metrics from detailed results.

    Results lacking a ``"category"`` key are grouped under ``"factual"``,
    matching the default used by ``evaluate_model``. A result is counted as
    correct when its ``"correct"`` value is truthy.

    Args:
        results: List of evaluation results with 'correct' and 'category' keys.

    Returns:
        Dictionary with, for every category seen (in first-seen order),
        ``<category>_accuracy``, ``<category>_total`` and
        ``<category>_correct``, plus ``overall_accuracy``,
        ``total_questions`` and ``total_correct``. Accuracies are 0.0 when
        there are no results.
    """
    # Single pass over the results. The category key used for counting is
    # the same defaulted key used for grouping, so results without a
    # "category" are no longer silently dropped from every count.
    totals: Dict[str, int] = {}
    corrects: Dict[str, int] = {}
    for r in results:
        category = r.get("category", "factual")
        totals[category] = totals.get(category, 0) + 1
        if r.get("correct", False):
            corrects[category] = corrects.get(category, 0) + 1

    metrics = {}
    for category, total_cat in totals.items():
        correct = corrects.get(category, 0)
        # total_cat >= 1 here because the category was seen at least once
        metrics[f"{category}_accuracy"] = correct / total_cat
        metrics[f"{category}_total"] = total_cat
        metrics[f"{category}_correct"] = correct

    total = sum(totals.values())
    total_correct = sum(corrects.values())

    metrics["overall_accuracy"] = total_correct / total if total > 0 else 0.0
    metrics["total_questions"] = total
    metrics["total_correct"] = total_correct

    return metrics


def compare_models(
    model_paths: List[str],
    test_questions_path: str,
    base_model: str = "mistralai/Mistral-7B-v0.3",
    output_path: Optional[str] = None,
    use_4bit: bool = True,
    max_new_tokens: int = 512,
    temperature: float = 0.1,
) -> Dict[str, EvaluationResult]:
    """
    Compare multiple models on the same test set.

    Each model is evaluated with ``evaluate_model`` in turn and its headline
    accuracies are printed. If ``output_path`` is given, a JSON summary
    keyed by model path is written there (parent directories are created).

    Args:
        model_paths: List of paths to fine-tuned models.
        test_questions_path: Path to test questions JSON.
        base_model: Base model name.
        output_path: Optional path to save comparison results.
        use_4bit: Forwarded to ``evaluate_model``.
        max_new_tokens: Forwarded to ``evaluate_model``.
        temperature: Forwarded to ``evaluate_model``.

    Returns:
        Dictionary mapping model paths to their evaluation results.
    """
    results: Dict[str, EvaluationResult] = {}

    for model_path in model_paths:
        print(f"\nEvaluating {model_path}...")
        result = evaluate_model(
            model_path=model_path,
            test_questions_path=test_questions_path,
            base_model=base_model,
            use_4bit=use_4bit,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        results[model_path] = result

        print(f"  Overall: {result.overall_accuracy:.1%}")
        print(f"  Factual: {result.factual_accuracy:.1%}")
        print(f"  Reasoning: {result.reasoning_accuracy:.1%}")
        print(f"  Calibration: {result.calibration_accuracy:.1%}")

    # Save comparison
    if output_path:
        comparison_data = {
            path: {
                "overall_accuracy": r.overall_accuracy,
                "factual_accuracy": r.factual_accuracy,
                "reasoning_accuracy": r.reasoning_accuracy,
                "calibration_accuracy": r.calibration_accuracy,
                # Counts make the accuracies interpretable (e.g. 1.0 of 2 vs of 200).
                "total_questions": r.total_questions,
                "correct_answers": r.correct_answers,
            }
            for path, r in results.items()
        }

        out_file = Path(output_path)
        # Avoid losing a long evaluation run to a missing output directory.
        out_file.parent.mkdir(parents=True, exist_ok=True)
        with out_file.open("w", encoding="utf-8") as f:
            json.dump(comparison_data, f, indent=2)

        print(f"\nComparison saved to {output_path}")

    return results