"""
AAM Diffusion LLM — Evaluation Script

Evaluates a trained AAM Diffusion Model on test data or
generates sample narratives from graph conditioning.

Usage:
    # Evaluate on test data
    python scripts/evaluate.py --checkpoint output/best.pt

    # Generate sample narratives
    python scripts/evaluate.py --checkpoint output/best.pt --generate

    # Interactive mode
    python scripts/evaluate.py --checkpoint output/best.pt --interactive
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

# Make the project package importable when this file is run as a script
# (scripts/evaluate.py -> repo root is two levels up).
sys.path.insert(0, str(Path(__file__).parent.parent))

from diffusion_llm.config.model_config import AamDiffusionConfig
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.inference.generator import AamGenerator

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for this script."""
    cli = argparse.ArgumentParser(description="Evaluate AAM Diffusion LLM")
    cli.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    cli.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
    cli.add_argument("--generate", action="store_true", help="Generate sample narratives")
    cli.add_argument("--interactive", action="store_true", help="Interactive mode")
    cli.add_argument("--test_data", type=str, default=None, help="Test data path (JSONL)")
    cli.add_argument("--n_steps", type=int, default=50, help="Inference denoising steps")
    cli.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    cli.add_argument("--language", type=str, default="id", help="Output language")
    return cli.parse_args()
|
|
|
|
def generate_samples(generator: AamGenerator, language: str) -> None:
    """Generate and print narratives for a fixed set of conditioning samples.

    Acts as a smoke test: each sample supplies a trigger question, evidence
    graph nodes, detected anomalies, and reasoning steps, which are forwarded
    verbatim to ``generator.generate``.

    Args:
        generator: Trained generator used to produce each narrative.
        language: Language code forwarded to ``generator.generate``.
    """
    samples = [
        {
            "trigger": "Siapa yang mencuri Snow Plum Pill?",
            "evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok", "Gyeryong Merchant Guild"],
            "anomalies": ["Tidak ada konsumsi pil baru di pasar gelap", "Pencuri menghilang tanpa jejak"],
            "reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi ketidaksesuaian pola"],
        },
        {
            "trigger": "Analisis pergerakan Diancang Five Swords",
            "evidence_nodes": ["Gu Ilmu", "Jang Hangi", "Diancang Five Swords", "Hefei"],
            "anomalies": ["Success rate pair lebih tinggi dari biasanya"],
            "reasoning_steps": ["Recall laporan terkait", "Pattern completion dari bukti"],
        },
        {
            "trigger": "Hubungan antara Ju Jangmok dan pencurian",
            "evidence_nodes": ["Ju Jangmok", "Snow Plum Pill", "dark_faction"],
            "anomalies": ["Ju Jangmok menghilang hari yang sama"],
            "reasoning_steps": ["Eliminasi tersangka obvious", "Verify konsistensi"],
        },
    ]

    print("\n" + "=" * 60)
    print(" AAM Diffusion LLM — Sample Generation")
    print("=" * 60)

    for i, sample in enumerate(samples, 1):
        result = generator.generate(
            trigger=sample["trigger"],
            evidence_nodes=sample["evidence_nodes"],
            anomalies=sample["anomalies"],
            reasoning_steps=sample["reasoning_steps"],
            language=language,
        )
        print(f"\n--- Sample {i} ---")
        print(f"Trigger: {sample['trigger']}")
        print(f"Evidence: {', '.join(sample['evidence_nodes'])}")
        print(f"Anomalies: {'; '.join(sample['anomalies'])}")
        # Was an f-string with no placeholders (ruff F541); output is identical.
        print("\nGenerated Narrative:")
        print(result.narrative)
        print(f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s]")
|
|
|
|
def interactive_mode(generator: AamGenerator, language: str) -> None:
    """Run a read-eval-print loop that generates a narrative per user prompt."""

    def _split_csv(raw: str) -> list[str] | None:
        # Empty input means "not provided" (None); otherwise split on commas,
        # dropping blank fragments (which may yield an empty list).
        if not raw:
            return None
        return [part.strip() for part in raw.split(",") if part.strip()]

    print("\n" + "=" * 60)
    print(" AAM Diffusion LLM — Interactive Mode")
    print(" Type 'quit' to exit")
    print("=" * 60)

    while True:
        trigger = input("\nTrigger/Question: ").strip()
        if trigger.lower() in ("quit", "exit", "q"):
            break

        evidence_nodes = _split_csv(input("Evidence nodes (comma-separated): ").strip())
        anomalies = _split_csv(input("Anomalies (comma-separated): ").strip())

        result = generator.generate(
            trigger=trigger,
            evidence_nodes=evidence_nodes,
            anomalies=anomalies,
            language=language,
        )
        print(f"\nGenerated Narrative:\n{result.narrative}")
        print(f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s, Confidence: {result.confidence:.1%}]")
|
|
|
|
def main() -> None:
    """CLI entry point: load the model and tokenizer, then run the chosen mode."""
    args = parse_args()

    logger.info("Loading model from %s", args.checkpoint)
    model = AamDiffusionModel.load(args.checkpoint)

    def _load_tokenizer() -> AamTokenizer:
        # Precedence: explicit --tokenizer path, then the default location
        # next to the checkpoint, then a fresh (untrained) tokenizer.
        if args.tokenizer:
            return AamTokenizer.load(args.tokenizer)
        default_path = Path(args.checkpoint).parent / "data" / "tokenizer.json"
        if default_path.exists():
            return AamTokenizer.load(default_path)
        logger.warning("No tokenizer found. Using untrained tokenizer.")
        return AamTokenizer()

    generator = AamGenerator(model, _load_tokenizer(), model.config)

    if args.interactive:
        interactive_mode(generator, args.language)
    elif args.generate:
        generate_samples(generator, args.language)
    else:
        logger.info("Use --generate or --interactive flag")
|
|
|
|
if __name__ == "__main__":
    main()
|
|