File size: 3,927 Bytes
518db7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Knowledge base for factual accuracy validation.

Provides 10 factual questions spanning geography, science, math, and history
for post-training validation of model knowledge retention.
"""

__all__ = [
    "KNOWLEDGE_BASE",
    "validate_knowledge",
    "save_report",
]

# Standard library
import json
import logging
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# 10-question knowledge base (optimized from 20)
# Schema: each entry has "q" (the question string) and "a" (a list of
# acceptable answers). An answer counts as correct when any entry of "a"
# appears case-insensitively as a substring of the model output — see
# validate_knowledge() below.
KNOWLEDGE_BASE: list[dict[str, Any]] = [
    # Geography + Science (5 questions)
    {"q": "What is the capital of France?", "a": ["Paris"]},
    {"q": "What gas do plants produce?", "a": ["oxygen", "O2"]},
    {"q": "How many planets in solar system?", "a": ["8", "eight"]},
    {"q": "What is H2O?", "a": ["water"]},
    {"q": "What orbits the Earth?", "a": ["moon", "the moon"]},

    # Math + History (5 questions)
    {"q": "What is 2 + 2?", "a": ["4", "four"]},
    {"q": "What is 10 * 10?", "a": ["100", "one hundred"]},
    {"q": "Who was the first US president?", "a": ["George Washington", "Washington"]},
    {"q": "What year did World War 2 end?", "a": ["1945"]},
    {"q": "Who wrote Romeo and Juliet?", "a": ["Shakespeare", "William Shakespeare"]},
]

def validate_knowledge(
    model: Any,
    questions: list[dict[str, Any]],
    max_tokens: int = 15
) -> dict[str, Any]:
    """
    Score a model's factual recall over a question/answer set.

    Each question is fed to ``model.generate_text``; the response is
    correct when any of the expected answers occurs case-insensitively
    as a substring of it. Generation errors are logged and recorded as
    failures rather than aborting the run.

    Args:
        model: Object exposing ``generate_text(prompt, max_length=...)``
        questions: List of {"q": str, "a": list[str]} question/answer pairs
        max_tokens: Max generation length

    Returns:
        {
            "accuracy": float,
            "correct": int,
            "total": int,
            "failed": list[dict]
        }
    """
    num_correct = 0
    misses: list[dict[str, Any]] = []

    for entry in questions:
        question = entry['q']
        expected = entry['a']
        try:
            reply = model.generate_text(question, max_length=max_tokens)
            # Case-insensitive substring match against every accepted answer.
            lowered = reply.lower()
            matched = False
            for candidate in expected:
                if candidate.lower() in lowered:
                    matched = True
                    break
        except Exception as exc:
            logger.warning(
                "Knowledge validation error",
                extra={"question": question, "error": str(exc)}
            )
            misses.append({
                'q': question,
                'expected': expected,
                'got': f"ERROR: {str(exc)}"
            })
            continue

        if matched:
            num_correct += 1
        else:
            misses.append({
                'q': question,
                'expected': expected,
                'got': reply[:50]  # Truncate for logging
            })

    total = len(questions)
    return {
        'accuracy': num_correct / total if total else 0.0,
        'correct': num_correct,
        'total': total,
        'failed': misses
    }

def save_report(results: dict, output_dir: Path) -> None:
    """
    Save validation report with path traversal protection.

    Writes ``quality_report.json`` into *output_dir* atomically
    (temp file + rename) so readers never observe a half-written file.

    Args:
        results: Validation results dictionary (must be JSON-serializable)
        output_dir: Directory to save report

    Raises:
        ValueError: If path traversal detected or directory invalid
        RuntimeError: If the report cannot be written
    """
    # Validate output_dir
    output_dir = Path(output_dir).resolve()
    if not output_dir.exists():
        raise ValueError(f"Output directory does not exist: {output_dir}")
    if not output_dir.is_dir():
        # exists() alone would accept a plain file and fail later at open().
        raise ValueError(f"Output path is not a directory: {output_dir}")

    # Prevent path traversal. Use a component-aware containment check:
    # a naive str.startswith() wrongly accepts sibling paths that merely
    # share a prefix (e.g. /data vs /data2).
    report_path = (output_dir / "quality_report.json").resolve()
    if not report_path.is_relative_to(output_dir):
        raise ValueError(f"Path traversal attempt detected: {report_path}")

    # Write atomically (temp file + rename)
    temp_path = report_path.with_suffix('.tmp')
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        temp_path.replace(report_path)  # Atomic rename
        logger.info(
            "Report saved",
            extra={"path": str(report_path)}
        )
    except Exception as e:
        # Best-effort cleanup of the partial temp file before surfacing.
        if temp_path.exists():
            temp_path.unlink()
        # Chain the cause so tracebacks retain the underlying error.
        raise RuntimeError(f"Failed to save report: {e}") from e