| """ |
| Knowledge base for factual accuracy validation. |
| |
| Provides 10 factual questions spanning geography, science, math, and history |
| for post-training validation of model knowledge retention. |
| """ |
|
|
# Public API of this module (PEP 8: dunders go after the docstring,
# before regular imports).
__all__ = [
    "KNOWLEDGE_BASE",
    "validate_knowledge",
    "save_report",
]
|
|
| |
| import json |
| import logging |
| from pathlib import Path |
| from typing import Any |
|
|
logger = logging.getLogger(__name__)  # module-scoped logger; handlers/levels are configured by the application
|
|
| |
# Ten fixed question/answer probes for post-training factual-recall checks.
# Each entry: "q" is the prompt sent to the model; "a" lists accepted answer
# strings, any of which counts as correct when found (case-insensitively) as
# a substring of the model's output — see validate_knowledge().
KNOWLEDGE_BASE = [

    # Geography and basic science
    {"q": "What is the capital of France?", "a": ["Paris"]},
    {"q": "What gas do plants produce?", "a": ["oxygen", "O2"]},
    {"q": "How many planets in solar system?", "a": ["8", "eight"]},
    {"q": "What is H2O?", "a": ["water"]},
    {"q": "What orbits the Earth?", "a": ["moon", "the moon"]},


    # Arithmetic and history
    {"q": "What is 2 + 2?", "a": ["4", "four"]},
    {"q": "What is 10 * 10?", "a": ["100", "one hundred"]},
    {"q": "Who was the first US president?", "a": ["George Washington", "Washington"]},
    {"q": "What year did World War 2 end?", "a": ["1945"]},
    {"q": "Who wrote Romeo and Juliet?", "a": ["Shakespeare", "William Shakespeare"]},
]
|
|
def validate_knowledge(
    model: Any,
    questions: list[dict[str, Any]],
    max_tokens: int = 15
) -> dict[str, Any]:
    """
    Check a model's factual recall against a question/answer set.

    Each question is posed to the model and the generated text is scanned
    case-insensitively for any of the accepted answer strings. A generation
    error is logged and recorded as a failure rather than aborting the run.

    Args:
        model: Object exposing ``generate_text(prompt, max_length=...)``
        questions: List of {"q": str, "a": list[str]} question/answer pairs
        max_tokens: Max generation length

    Returns:
        {
            "accuracy": float,
            "correct": int,
            "total": int,
            "failed": list[dict]
        }
    """
    hits = 0
    misses: list[dict[str, Any]] = []

    for entry in questions:
        try:
            reply = model.generate_text(entry['q'], max_length=max_tokens)
        except Exception as exc:
            # Best-effort: record the error against this question and move on.
            logger.warning(
                "Knowledge validation error",
                extra={"question": entry['q'], "error": str(exc)}
            )
            misses.append({
                'q': entry['q'],
                'expected': entry['a'],
                'got': f"ERROR: {str(exc)}"
            })
            continue

        # Substring match against any accepted answer, case-insensitive.
        lowered = reply.lower()
        if any(candidate.lower() in lowered for candidate in entry['a']):
            hits += 1
        else:
            misses.append({
                'q': entry['q'],
                'expected': entry['a'],
                'got': reply[:50]  # truncate for readable reports
            })

    total = len(questions)
    return {
        'accuracy': hits / total if total else 0.0,
        'correct': hits,
        'total': total,
        'failed': misses
    }
|
|
def save_report(results: dict, output_dir: Path) -> None:
    """
    Save validation report atomically, with path traversal protection.

    The report is written to ``<output_dir>/quality_report.json`` via a
    temporary file that is renamed into place, so readers never observe a
    partially written report.

    Args:
        results: Validation results dictionary (must be JSON-serializable)
        output_dir: Existing directory to save report into

    Raises:
        ValueError: If path traversal detected or directory invalid
        RuntimeError: If writing or renaming the report fails
    """
    output_dir = Path(output_dir).resolve()
    # is_dir() also rejects a path that exists but is a regular file,
    # which bare exists() would have let through to fail later on open().
    if not output_dir.is_dir():
        raise ValueError(f"Output directory does not exist: {output_dir}")

    report_path = (output_dir / "quality_report.json").resolve()
    # Compare path components, not string prefixes: a startswith() check
    # wrongly accepts sibling dirs ("/out-evil" starts with "/out") and a
    # symlinked quality_report.json that resolves outside output_dir.
    if report_path.parent != output_dir:
        raise ValueError(f"Path traversal attempt detected: {report_path}")

    # Write-then-rename so a crash mid-write never leaves a truncated report.
    temp_path = report_path.with_suffix('.tmp')
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        temp_path.replace(report_path)
        logger.info(
            "Report saved",
            extra={"path": str(report_path)}
        )
    except Exception as e:
        # Best-effort cleanup of the partial temp file before re-raising.
        if temp_path.exists():
            temp_path.unlink()
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to save report: {e}") from e
|
|