""" Knowledge base for factual accuracy validation. Provides 10 factual questions spanning geography, science, math, and history for post-training validation of model knowledge retention. """ __all__ = [ "KNOWLEDGE_BASE", "validate_knowledge", "save_report", ] # Standard library import json import logging from pathlib import Path from typing import Any logger = logging.getLogger(__name__) # 10-question knowledge base (optimized from 20) KNOWLEDGE_BASE = [ # Geography + Science (5 questions) {"q": "What is the capital of France?", "a": ["Paris"]}, {"q": "What gas do plants produce?", "a": ["oxygen", "O2"]}, {"q": "How many planets in solar system?", "a": ["8", "eight"]}, {"q": "What is H2O?", "a": ["water"]}, {"q": "What orbits the Earth?", "a": ["moon", "the moon"]}, # Math + History (5 questions) {"q": "What is 2 + 2?", "a": ["4", "four"]}, {"q": "What is 10 * 10?", "a": ["100", "one hundred"]}, {"q": "Who was the first US president?", "a": ["George Washington", "Washington"]}, {"q": "What year did World War 2 end?", "a": ["1945"]}, {"q": "Who wrote Romeo and Juliet?", "a": ["Shakespeare", "William Shakespeare"]}, ] def validate_knowledge( model: Any, questions: list[dict[str, Any]], max_tokens: int = 15 ) -> dict[str, Any]: """ Validate factual accuracy (simple function approach). Args: model: Model to test questions: List of {"q": str, "a": list[str]} question/answer pairs max_tokens: Max generation length Returns: { "accuracy": float, "correct": int, "total": int, "failed": list[dict] } """ correct = 0 failed = [] for item in questions: try: output = model.generate_text(item['q'], max_length=max_tokens) is_correct = any(ans.lower() in output.lower() for ans in item['a']) if is_correct: correct += 1 else: failed.append({ 'q': item['q'], 'expected': item['a'], 'got': output[:50] # Truncate for logging }) except Exception as e: logger.warning( "Knowledge validation error", extra={"question": item['q'], "error": str(e)} ) failed.append({ 'q': item['q'], 'expected': item['a'], 'got': f"ERROR: {str(e)}" }) return { 'accuracy': correct / len(questions) if questions else 0.0, 'correct': correct, 'total': len(questions), 'failed': failed } def save_report(results: dict, output_dir: Path): """ Save validation report with path traversal protection. Args: results: Validation results dictionary output_dir: Directory to save report Raises: ValueError: If path traversal detected or directory invalid """ # Validate output_dir output_dir = Path(output_dir).resolve() if not output_dir.exists(): raise ValueError(f"Output directory does not exist: {output_dir}") # Prevent path traversal report_path = (output_dir / "quality_report.json").resolve() if not str(report_path).startswith(str(output_dir)): raise ValueError(f"Path traversal attempt detected: {report_path}") # Write atomically (temp file + rename) temp_path = report_path.with_suffix('.tmp') try: with open(temp_path, 'w') as f: json.dump(results, f, indent=2) temp_path.replace(report_path) # Atomic rename logger.info( "Report saved", extra={"path": str(report_path)} ) except Exception as e: if temp_path.exists(): temp_path.unlink() raise RuntimeError(f"Failed to save report: {e}")