# sem-v6-training / src/sem_v6/validation/knowledge_base.py
# (uploaded via huggingface_hub, revision 518db7a)
"""
Knowledge base for factual accuracy validation.
Provides 10 factual questions spanning geography, science, math, and history
for post-training validation of model knowledge retention.
"""
# Public API of this module.
__all__ = [
"KNOWLEDGE_BASE",
"validate_knowledge",
"save_report",
]
# Standard library
import json
import logging
from pathlib import Path
from typing import Any
# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)
# 10-question knowledge base (optimized from 20)
# 10-question knowledge base (optimized from 20).
# Each entry: {"q": question string, "a": list of acceptable answers}.
# Matching is case-insensitive substring containment (see validate_knowledge),
# so multiple phrasings ("8"/"eight") are listed where useful.
KNOWLEDGE_BASE = [
# Geography + Science (5 questions)
{"q": "What is the capital of France?", "a": ["Paris"]},
{"q": "What gas do plants produce?", "a": ["oxygen", "O2"]},
{"q": "How many planets in solar system?", "a": ["8", "eight"]},
{"q": "What is H2O?", "a": ["water"]},
{"q": "What orbits the Earth?", "a": ["moon", "the moon"]},
# Math + History (5 questions)
{"q": "What is 2 + 2?", "a": ["4", "four"]},
{"q": "What is 10 * 10?", "a": ["100", "one hundred"]},
{"q": "Who was the first US president?", "a": ["George Washington", "Washington"]},
{"q": "What year did World War 2 end?", "a": ["1945"]},
{"q": "Who wrote Romeo and Juliet?", "a": ["Shakespeare", "William Shakespeare"]},
]
def validate_knowledge(
model: Any,
questions: list[dict[str, Any]],
max_tokens: int = 15
) -> dict[str, Any]:
    """
    Validate factual accuracy (simple function approach).

    An answer counts as correct when any accepted answer string appears
    (case-insensitively) in the model's generated text. Generation errors
    are logged and recorded as failures rather than aborting the run.

    Args:
        model: Model to test; must expose ``generate_text(prompt, max_length=...)``
        questions: List of {"q": str, "a": list[str]} question/answer pairs
        max_tokens: Max generation length

    Returns:
        {
            "accuracy": float,
            "correct": int,
            "total": int,
            "failed": list[dict]
        }
    """
    hits = 0
    misses: list[dict[str, Any]] = []
    for entry in questions:
        try:
            reply = model.generate_text(entry['q'], max_length=max_tokens)
            matched = any(option.lower() in reply.lower() for option in entry['a'])
        except Exception as exc:
            # Best-effort: record the error as a failed question and move on.
            logger.warning(
                "Knowledge validation error",
                extra={"question": entry['q'], "error": str(exc)}
            )
            misses.append({
                'q': entry['q'],
                'expected': entry['a'],
                'got': f"ERROR: {str(exc)}"
            })
            continue
        if matched:
            hits += 1
        else:
            misses.append({
                'q': entry['q'],
                'expected': entry['a'],
                'got': reply[:50]  # Truncate for logging
            })
    total = len(questions)
    return {
        'accuracy': hits / total if total else 0.0,
        'correct': hits,
        'total': total,
        'failed': misses
    }
def save_report(results: dict, output_dir: Path):
    """
    Save validation report with path traversal protection.

    Writes ``quality_report.json`` into *output_dir* via a temp file plus
    atomic rename, so readers never observe a partially written report.

    Args:
        results: Validation results dictionary (must be JSON-serializable)
        output_dir: Directory to save report

    Raises:
        ValueError: If path traversal detected or directory invalid
        RuntimeError: If the report cannot be written
    """
    # Validate output_dir
    output_dir = Path(output_dir).resolve()
    if not output_dir.exists():
        raise ValueError(f"Output directory does not exist: {output_dir}")
    # Prevent path traversal. Use is_relative_to() rather than a
    # str.startswith() prefix check, which wrongly accepts sibling
    # directories (e.g. "/out-evil" passes a prefix test for "/out").
    report_path = (output_dir / "quality_report.json").resolve()
    if not report_path.is_relative_to(output_dir):
        raise ValueError(f"Path traversal attempt detected: {report_path}")
    # Write atomically (temp file + rename)
    temp_path = report_path.with_suffix('.tmp')
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        temp_path.replace(report_path)  # Atomic rename
        logger.info(
            "Report saved",
            extra={"path": str(report_path)}
        )
    except Exception as e:
        # Best-effort cleanup of the partial temp file; missing_ok avoids
        # a racy exists()-then-unlink pair.
        temp_path.unlink(missing_ok=True)
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to save report: {e}") from e