"""
Knowledge base for factual accuracy validation.
Provides 10 factual questions spanning geography, science, math, and history
for post-training validation of model knowledge retention.
"""
__all__ = [
"KNOWLEDGE_BASE",
"validate_knowledge",
"save_report",
]
# Standard library
import json
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# 10-question knowledge base (optimized from 20)
# Each entry pairs a prompt ("q") with a list of acceptable answers ("a").
# validate_knowledge() treats a response as correct when ANY listed answer
# appears as a case-insensitive substring of the model output, so keep
# answers short and unambiguous.
KNOWLEDGE_BASE: list[dict[str, Any]] = [
    # Geography + Science (5 questions)
    {"q": "What is the capital of France?", "a": ["Paris"]},
    {"q": "What gas do plants produce?", "a": ["oxygen", "O2"]},
    {"q": "How many planets in solar system?", "a": ["8", "eight"]},
    {"q": "What is H2O?", "a": ["water"]},
    {"q": "What orbits the Earth?", "a": ["moon", "the moon"]},
    # Math + History (5 questions)
    {"q": "What is 2 + 2?", "a": ["4", "four"]},
    {"q": "What is 10 * 10?", "a": ["100", "one hundred"]},
    {"q": "Who was the first US president?", "a": ["George Washington", "Washington"]},
    {"q": "What year did World War 2 end?", "a": ["1945"]},
    {"q": "Who wrote Romeo and Juliet?", "a": ["Shakespeare", "William Shakespeare"]},
]
def validate_knowledge(
    model: Any,
    questions: list[dict[str, Any]],
    max_tokens: int = 15
) -> dict[str, Any]:
    """
    Score a model's factual recall over a set of question/answer pairs.

    A question counts as correct when any acceptable answer appears as a
    case-insensitive substring of the generated text. Generation errors
    are logged and recorded as failures rather than aborting the run.

    Args:
        model: Model under test; must expose generate_text(prompt, max_length)
        questions: List of {"q": str, "a": list[str]} question/answer pairs
        max_tokens: Max generation length

    Returns:
        Dict with keys:
            "accuracy" (float): fraction answered correctly (0.0 if empty)
            "correct" (int): number of correct answers
            "total" (int): number of questions asked
            "failed" (list[dict]): records for each miss or error
    """
    num_correct = 0
    misses: list[dict[str, Any]] = []

    for entry in questions:
        try:
            response = model.generate_text(entry['q'], max_length=max_tokens)
            matched = any(ans.lower() in response.lower() for ans in entry['a'])
        except Exception as exc:
            # Best-effort: record the failure and keep validating.
            logger.warning(
                "Knowledge validation error",
                extra={"question": entry['q'], "error": str(exc)}
            )
            misses.append({
                'q': entry['q'],
                'expected': entry['a'],
                'got': f"ERROR: {str(exc)}"
            })
            continue

        if matched:
            num_correct += 1
        else:
            misses.append({
                'q': entry['q'],
                'expected': entry['a'],
                'got': response[:50]  # Truncate for logging
            })

    total = len(questions)
    return {
        'accuracy': num_correct / total if total else 0.0,
        'correct': num_correct,
        'total': total,
        'failed': misses
    }
def save_report(results: dict, output_dir: Path) -> None:
    """
    Save validation report atomically with path traversal protection.

    The report is written to <output_dir>/quality_report.json via a
    temp-file-then-rename sequence so readers never observe a partial file.

    Args:
        results: Validation results dictionary (must be JSON-serializable)
        output_dir: Directory to save report

    Raises:
        ValueError: If the output directory does not exist or path
            traversal is detected
        RuntimeError: If serializing or writing the report fails
    """
    # Validate output_dir
    output_dir = Path(output_dir).resolve()
    if not output_dir.exists():
        raise ValueError(f"Output directory does not exist: {output_dir}")

    # Prevent path traversal. Path.is_relative_to compares path components,
    # avoiding the string-prefix pitfall of startswith (e.g. "/data"
    # would wrongly accept "/data-other/...").
    report_path = (output_dir / "quality_report.json").resolve()
    if not report_path.is_relative_to(output_dir):
        raise ValueError(f"Path traversal attempt detected: {report_path}")

    # Write atomically (temp file + rename)
    temp_path = report_path.with_suffix('.tmp')
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        temp_path.replace(report_path)  # Atomic rename on POSIX
        logger.info(
            "Report saved",
            extra={"path": str(report_path)}
        )
    except Exception as e:
        # Clean up the partial temp file before surfacing the error,
        # chaining the original cause for debuggability.
        if temp_path.exists():
            temp_path.unlink()
        raise RuntimeError(f"Failed to save report: {e}") from e