#!/usr/bin/env python3
"""
BioRLHF Model Evaluation Example

This script demonstrates how to evaluate a fine-tuned model on
biological reasoning tasks.

Usage:
    python evaluate_model.py --model ./biorlhf_model --test-set kmp_test_set.json
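
    To disable 4-bit quantization, decode greedily, and save detailed results
    (all flags defined below):

    python evaluate_model.py --model ./biorlhf_model --no-quantization \
        --temperature 0 --output results.json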
"""

import argparse
import json
import sys
from pathlib import Path

from biorlhf import evaluate_model
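# evaluate_model returns a results object exposing overall_accuracy,
# per-category accuracies, correct_answers, total_questions, and a
# detailed_results list of per-question dicts, as consumed below.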


def main():
    """Run model evaluation."""
    parser = argparse.ArgumentParser(
        description="Evaluate a fine-tuned BioRLHF model"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the fine-tuned model directory",
    )
    parser.add_argument(
        "--test-set",
        type=str,
        default="kmp_test_set.json",
        help="Path to test questions JSON file",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model name",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output path for detailed results JSON",
    )
    parser.add_argument(
        "--no-quantization",
        action="store_true",
        help="Disable 4-bit quantization",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Generation temperature (0 for greedy)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate",
    )

    args = parser.parse_args()

    # Fail fast if the test set is missing, before loading the model
    if not Path(args.test_set).exists():
        print(f"Error: Test set not found at {args.test_set}", file=sys.stderr)
        print("\nYou can create a test set or use the default one from the data folder.", file=sys.stderr)
        sys.exit(1)

    print("=" * 60)
    print("BioRLHF Model Evaluation")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Base Model: {args.base_model}")
    print(f"Test Set: {args.test_set}")
    print(f"Quantization: {'Disabled' if args.no_quantization else '4-bit'}")
    print("=" * 60)

    # Run evaluation
    results = evaluate_model(
        model_path=args.model,
        test_questions_path=args.test_set,
        base_model=args.base_model,
        use_4bit=not args.no_quantization,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
    )

    # Print results
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(f"\nOverall Accuracy: {results.overall_accuracy:.1%} ({results.correct_answers}/{results.total_questions})")
    print(f"\nBy Category:")
    print(f"  Factual:     {results.factual_accuracy:.1%}")
    print(f"  Reasoning:   {results.reasoning_accuracy:.1%}")
    print(f"  Calibration: {results.calibration_accuracy:.1%}")

    # Show detailed results
    print("\n" + "-" * 60)
    print("Detailed Results:")
    print("-" * 60)

    for i, r in enumerate(results.detailed_results, 1):
        status = "CORRECT" if r["correct"] else "WRONG"
        # Truncate long fields, appending an ellipsis only when text was cut
        question = r["question"][:80] + ("..." if len(r["question"]) > 80 else "")
        expected = r["expected"][:50] + ("..." if len(r["expected"]) > 50 else "")
        response = r["response"][:100] + ("..." if len(r["response"]) > 100 else "")
        print(f"\n{i}. [{r['category'].upper()}] {status}")
        print(f"   Q: {question}")
        print(f"   Expected: {expected}")
        print(f"   Response: {response}")

    # Save detailed results if requested
    if args.output:
        output_data = {
            "model_path": args.model,
            "base_model": args.base_model,
            "test_set": args.test_set,
            "metrics": {
                "overall_accuracy": results.overall_accuracy,
                "factual_accuracy": results.factual_accuracy,
                "reasoning_accuracy": results.reasoning_accuracy,
                "calibration_accuracy": results.calibration_accuracy,
                "total_questions": results.total_questions,
                "correct_answers": results.correct_answers,
            },
            "detailed_results": results.detailed_results,
        }

        with open(args.output, "w") as f:
            json.dump(output_data, f, indent=2)

        print(f"\nDetailed results saved to: {args.output}")

    print("\n" + "=" * 60)
    print("Evaluation complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()