#!/usr/bin/env python3
"""
AAM Diffusion LLM — Evaluation Script
Evaluates a trained AAM Diffusion Model on test data or
generates sample narratives from graph conditioning.
Usage:
# Evaluate on test data
python scripts/evaluate.py --checkpoint output/best.pt
# Generate sample narratives
python scripts/evaluate.py --checkpoint output/best.pt --generate
# Interactive mode
python scripts/evaluate.py --checkpoint output/best.pt --interactive
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from diffusion_llm.config.model_config import AamDiffusionConfig
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.inference.generator import AamGenerator
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate AAM Diffusion LLM")
    parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
    parser.add_argument("--generate", action="store_true", help="Generate sample narratives")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")
    parser.add_argument("--test_data", type=str, default=None, help="Test data path (JSONL)")
    parser.add_argument("--n_steps", type=int, default=50, help="Inference denoising steps")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--language", type=str, default="id", help="Output language")
    return parser.parse_args()
def generate_samples(generator: AamGenerator, language: str) -> None:
    """Generate sample narratives from predefined graph conditioning."""
    samples = [
        {
            "trigger": "Siapa yang mencuri Snow Plum Pill?",
            "evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok", "Gyeryong Merchant Guild"],
            "anomalies": ["Tidak ada konsumsi pil baru di pasar gelap", "Pencuri menghilang tanpa jejak"],
            "reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi ketidaksesuaian pola"],
        },
        {
            "trigger": "Analisis pergerakan Diancang Five Swords",
            "evidence_nodes": ["Gu Ilmu", "Jang Hangi", "Diancang Five Swords", "Hefei"],
            "anomalies": ["Success rate pair lebih tinggi dari biasanya"],
            "reasoning_steps": ["Recall laporan terkait", "Pattern completion dari bukti"],
        },
        {
            "trigger": "Hubungan antara Ju Jangmok dan pencurian",
            "evidence_nodes": ["Ju Jangmok", "Snow Plum Pill", "dark_faction"],
            "anomalies": ["Ju Jangmok menghilang hari yang sama"],
            "reasoning_steps": ["Eliminasi tersangka obvious", "Verify konsistensi"],
        },
    ]

    print("\n" + "=" * 60)
    print(" AAM Diffusion LLM — Sample Generation")
    print("=" * 60)

    for i, sample in enumerate(samples, 1):
        result = generator.generate(
            trigger=sample["trigger"],
            evidence_nodes=sample["evidence_nodes"],
            anomalies=sample["anomalies"],
            reasoning_steps=sample["reasoning_steps"],
            language=language,
        )
        print(f"\n--- Sample {i} ---")
        print(f"Trigger: {sample['trigger']}")
        print(f"Evidence: {', '.join(sample['evidence_nodes'])}")
        print(f"Anomalies: {'; '.join(sample['anomalies'])}")
        print("\nGenerated Narrative:")
        print(result.narrative)
        print(f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s]")
def interactive_mode(generator: AamGenerator, language: str) -> None:
    """Interactive generation mode."""
    print("\n" + "=" * 60)
    print(" AAM Diffusion LLM — Interactive Mode")
    print(" Type 'quit' to exit")
    print("=" * 60)

    while True:
        # Treat Ctrl-D/Ctrl-C at the prompt as a clean exit.
        try:
            trigger = input("\nTrigger/Question: ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if trigger.lower() in ("quit", "exit", "q"):
            break
        evidence = input("Evidence nodes (comma-separated): ").strip()
        evidence_nodes = [e.strip() for e in evidence.split(",") if e.strip()] if evidence else None
        anomalies_input = input("Anomalies (comma-separated): ").strip()
        anomalies = [a.strip() for a in anomalies_input.split(",") if a.strip()] if anomalies_input else None

        result = generator.generate(
            trigger=trigger,
            evidence_nodes=evidence_nodes,
            anomalies=anomalies,
            language=language,
        )
        print(f"\nGenerated Narrative:\n{result.narrative}")
        print(
            f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s, "
            f"Confidence: {result.confidence:.1%}]"
        )
def main() -> None:
    args = parse_args()

    # Load model
    logger.info("Loading model from %s", args.checkpoint)
    model = AamDiffusionModel.load(args.checkpoint)

    # Load or create tokenizer
    if args.tokenizer:
        tokenizer = AamTokenizer.load(args.tokenizer)
    else:
        # Fall back to the tokenizer stored alongside the checkpoint
        # (in its data/ subdirectory).
        tokenizer_path = Path(args.checkpoint).parent / "data" / "tokenizer.json"
        if tokenizer_path.exists():
            tokenizer = AamTokenizer.load(tokenizer_path)
        else:
            logger.warning("No tokenizer found. Using untrained tokenizer.")
            tokenizer = AamTokenizer()

    # Create generator
    generator = AamGenerator(model, tokenizer, model.config)

    if args.interactive:
        interactive_mode(generator, args.language)
    elif args.generate:
        generate_samples(generator, args.language)
    elif args.test_data:
        evaluate_test_data(generator, args.test_data, args.language)
    else:
        logger.info("Nothing to do: pass --generate, --interactive, or --test_data")
if __name__ == "__main__":
    main()