File size: 6,099 Bytes
2d7e335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/env python3
"""
AAM Diffusion LLM — Evaluation Script

Evaluates a trained AAM Diffusion Model on test data or
generates sample narratives from graph conditioning.

Usage:
    # Evaluate on test data
    python scripts/evaluate.py --checkpoint output/best.pt

    # Generate sample narratives
    python scripts/evaluate.py --checkpoint output/best.pt --generate

    # Interactive mode
    python scripts/evaluate.py --checkpoint output/best.pt --interactive
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from diffusion_llm.config.model_config import AamDiffusionConfig
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.inference.generator import AamGenerator

# Configure root logging once at import time so that the package loggers
# (diffusion_llm.*) inherit the same timestamped, leveled format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-level logger, named after this script's module per stdlib convention.
logger = logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for evaluation / generation runs.

    Returns:
        argparse.Namespace with checkpoint/tokenizer paths, run-mode flags,
        and sampling knobs (steps, temperature, language).
    """
    p = argparse.ArgumentParser(description="Evaluate AAM Diffusion LLM")
    # Model artifacts.
    p.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    p.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
    # Run modes (mutually independent flags; main() checks interactive first).
    p.add_argument("--generate", action="store_true", help="Generate sample narratives")
    p.add_argument("--interactive", action="store_true", help="Interactive mode")
    # Data and sampling parameters.
    p.add_argument("--test_data", type=str, default=None, help="Test data path (JSONL)")
    p.add_argument("--n_steps", type=int, default=50, help="Inference denoising steps")
    p.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    p.add_argument("--language", type=str, default="id", help="Output language")
    return p.parse_args()


def generate_samples(generator: AamGenerator, language: str) -> None:
    """Run the generator over a fixed set of demo graph conditionings.

    Each demo case supplies a trigger question plus evidence nodes,
    anomalies, and reasoning steps; the resulting narrative and timing
    stats are printed to stdout.
    """
    demo_cases = [
        {
            "trigger": "Siapa yang mencuri Snow Plum Pill?",
            "evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok", "Gyeryong Merchant Guild"],
            "anomalies": ["Tidak ada konsumsi pil baru di pasar gelap", "Pencuri menghilang tanpa jejak"],
            "reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi ketidaksesuaian pola"],
        },
        {
            "trigger": "Analisis pergerakan Diancang Five Swords",
            "evidence_nodes": ["Gu Ilmu", "Jang Hangi", "Diancang Five Swords", "Hefei"],
            "anomalies": ["Success rate pair lebih tinggi dari biasanya"],
            "reasoning_steps": ["Recall laporan terkait", "Pattern completion dari bukti"],
        },
        {
            "trigger": "Hubungan antara Ju Jangmok dan pencurian",
            "evidence_nodes": ["Ju Jangmok", "Snow Plum Pill", "dark_faction"],
            "anomalies": ["Ju Jangmok menghilang hari yang sama"],
            "reasoning_steps": ["Eliminasi tersangka obvious", "Verify konsistensi"],
        },
    ]

    banner = "=" * 60
    print("\n" + banner)
    print("  AAM Diffusion LLM — Sample Generation")
    print(banner)

    for idx, case in enumerate(demo_cases, 1):
        outcome = generator.generate(
            trigger=case["trigger"],
            evidence_nodes=case["evidence_nodes"],
            anomalies=case["anomalies"],
            reasoning_steps=case["reasoning_steps"],
            language=language,
        )
        print(f"\n--- Sample {idx} ---")
        print("Trigger: " + case["trigger"])
        print("Evidence: " + ", ".join(case["evidence_nodes"]))
        print("Anomalies: " + "; ".join(case["anomalies"]))
        print("\nGenerated Narrative:")
        print(outcome.narrative)
        print(f"\n[Steps: {outcome.n_diffusion_steps}, Time: {outcome.generation_time_s:.2f}s]")


def interactive_mode(generator: AamGenerator, language: str) -> None:
    """Interactive generation loop.

    Repeatedly prompts for a trigger question plus optional evidence nodes
    and anomalies, then prints the generated narrative with timing and
    confidence stats.

    Exits cleanly on 'quit'/'exit'/'q', on EOF (piped or closed stdin),
    or on Ctrl-C — previously EOFError/KeyboardInterrupt from input()
    crashed the script with a traceback.
    """
    print("\n" + "=" * 60)
    print("  AAM Diffusion LLM — Interactive Mode")
    print("  Type 'quit' to exit")
    print("=" * 60)

    while True:
        try:
            trigger = input("\nTrigger/Question: ").strip()
            if trigger.lower() in ("quit", "exit", "q"):
                break

            evidence = input("Evidence nodes (comma-separated): ").strip()
            anomalies_input = input("Anomalies (comma-separated): ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin exhausted/closed or user hit Ctrl-C: leave the loop
            # instead of dying with a traceback.
            print()
            break

        # Empty answers mean "no conditioning" (None), not an empty list.
        evidence_nodes = [e.strip() for e in evidence.split(",") if e.strip()] if evidence else None
        anomalies = [a.strip() for a in anomalies_input.split(",") if a.strip()] if anomalies_input else None

        result = generator.generate(
            trigger=trigger,
            evidence_nodes=evidence_nodes,
            anomalies=anomalies,
            language=language,
        )
        print(f"\nGenerated Narrative:\n{result.narrative}")
        print(f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s, Confidence: {result.confidence:.1%}]")


def main() -> None:
    """Script entry point: load artifacts, build a generator, dispatch on mode."""
    args = parse_args()

    logger.info("Loading model from %s", args.checkpoint)
    model = AamDiffusionModel.load(args.checkpoint)

    # Tokenizer resolution order: explicit --tokenizer flag, then a
    # tokenizer stored beside the checkpoint, then an untrained fallback.
    if args.tokenizer:
        tokenizer = AamTokenizer.load(args.tokenizer)
    else:
        sibling = Path(args.checkpoint).parent / "data" / "tokenizer.json"
        if sibling.exists():
            tokenizer = AamTokenizer.load(sibling)
        else:
            logger.warning("No tokenizer found. Using untrained tokenizer.")
            tokenizer = AamTokenizer()

    generator = AamGenerator(model, tokenizer, model.config)

    # --interactive wins over --generate when both flags are given.
    if args.interactive:
        interactive_mode(generator, args.language)
    elif args.generate:
        generate_samples(generator, args.language)
    else:
        logger.info("Use --generate or --interactive flag")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()