File size: 4,215 Bytes
fd0b01f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# 3-model quality comparison on the held-out test split:
#   row 1: raw input (no cleanup) -> aggregate metrics
#   row 2: qwen base zero-shot with our system prompt
#   row 3: qwen + fine-tuned lora adapter
#
# also includes an ADVERSARIAL question check: the base model's documented
# failure was answering questions instead of cleaning them. we record base vs
# fine-tune output on a small list of question-shaped inputs so we can
# visually confirm fine-tune cleans rather than answers.
#
# writes runs/<run-id>/eval.json with all three rows (row 2 is omitted under
# --skip-base) plus the adversarial outputs.

import argparse
import gc
import json
from pathlib import Path

from cleanup.config import load_train_config
from cleanup.data.download import load_pairs
from cleanup.eval.metrics import (
    evaluate_one,
    make_qwen_generator,
    make_raw_generator,
    write_eval,
)


# question-shaped inputs that reproduce the prototype's documented failure:
# the base model replies to them instead of stripping the disfluencies. a good
# fine-tune returns the cleaned question itself (proper punct/case), never an
# answer. deliberately small but representative; append cases as new failure
# modes turn up.
ADVERSARIAL: list[str] = [
    "um whats the capital of france",  # leading filler "um"
    "can you can you write me a poem about the sea",  # repeated "can you"
    "so like what is two plus two i mean",  # fillers "so like" + trailing "i mean"
    "uh how do i sort a list in python",  # leading filler "uh"
    "hey what time is it in tokyo right now",  # leading "hey"
]


# console summary header; _format_summary_row uses matching column widths.
_SUMMARY_HEADER = "model        | disfluency | punct f1 | faithful | pass rate"


def _parse_args() -> argparse.Namespace:
    """Parse the eval script's command-line flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/train.yaml")
    parser.add_argument("--data-dir", default="data/pairs")
    parser.add_argument("--runs-dir", default="runs")
    parser.add_argument("--run-id", required=True)
    parser.add_argument("--max-rows", type=int, default=None)
    parser.add_argument("--smoke", action="store_true")
    parser.add_argument("--skip-base", action="store_true", help="skip qwen base baseline (saves time)")
    return parser.parse_args()


def _format_summary_row(name: str, metrics: dict) -> str:
    """Render one row of the console summary table.

    Column widths match ``_SUMMARY_HEADER`` so the ``|`` separators line up;
    the disfluency column is 10 wide to fit its header label.

    Args:
        name: row label ("raw", "base" or "fine_tuned").
        metrics: one model's entry from the report. ``disfluency_removal_rate``
            may be None (shown as "n/a"); ``punctuation_f1``,
            ``faithfulness_mean`` and ``pass_rate`` are formatted as floats.

    Returns:
        The formatted row, without a trailing newline.
    """
    d = metrics["disfluency_removal_rate"]
    d_str = "n/a" if d is None else f"{d:.3f}"
    return (
        f"{name:<12} | {d_str:>10} | {metrics['punctuation_f1']:>8.3f} | "
        f"{metrics['faithfulness_mean']:>8.3f} | {metrics['pass_rate']:>9.3f}"
    )


def main() -> None:
    """Run the 3-row quality comparison plus the adversarial question check.

    Evaluates the raw baseline, the base model (unless --skip-base) and the
    fine-tuned adapter on the test split, writes the report via ``write_eval``
    into the run directory, then prints a summary table and the adversarial
    outputs side by side.

    Raises:
        FileNotFoundError: if ``<runs-dir>/<run-id>/model`` does not exist.
    """
    args = _parse_args()

    cfg = load_train_config(args.config)
    run_dir = Path(args.runs_dir) / args.run_id
    adapter_dir = run_dir / "model"
    if not adapter_dir.exists():
        raise FileNotFoundError(f"no adapter at {adapter_dir}; train first")

    # --smoke caps the split at 40 rows and takes precedence over --max-rows.
    max_rows = 40 if args.smoke else args.max_rows
    test_rows = load_pairs(args.data_dir, "test", max_rows)
    print(f"[eval] {len(test_rows)} test rows")

    report: dict = {}
    print("[eval] row 1: raw baseline")
    report["raw"] = evaluate_one(test_rows, make_raw_generator())

    base_adv_outputs = None
    if not args.skip_base:
        print("[eval] row 2: qwen base zero-shot")
        base_gen = make_qwen_generator(cfg.base_model)
        report["base"] = evaluate_one(test_rows, base_gen)
        # produce the base side of the adversarial check while this model is
        # loaded, then release it so the base and fine-tuned models are never
        # resident at the same time.
        # NOTE(review): assumes the returned generator is the only reference to
        # the base weights; accelerator caches may still need an explicit flush.
        base_adv_outputs = [base_gen(q) for q in ADVERSARIAL]
        del base_gen
        gc.collect()

    print("[eval] row 3: qwen fine-tuned")
    ft_gen = make_qwen_generator(cfg.base_model, adapter_path=str(adapter_dir))
    report["fine_tuned"] = evaluate_one(test_rows, ft_gen)

    # adversarial question check: record base vs fine-tune output side by side
    # so we can visually confirm the fine-tune cleans rather than answers.
    print("[eval] adversarial: do questions get cleaned, not answered?")
    adversarial_rows = []
    for i, q in enumerate(ADVERSARIAL):
        row = {"raw": q}
        if base_adv_outputs is not None:
            row["base"] = base_adv_outputs[i]
        row["fine_tuned"] = ft_gen(q)
        adversarial_rows.append(row)
    report["adversarial"] = adversarial_rows

    write_eval(report, run_dir)
    print(f"[eval] wrote {run_dir / 'eval.json'}")

    print()
    print(_SUMMARY_HEADER)
    for k in ("raw", "base", "fine_tuned"):
        # "base" is absent under --skip-base.
        if k in report:
            print(_format_summary_row(k, report[k]))

    print()
    print("[eval] adversarial check (look for fine_tuned to CLEAN not ANSWER):")
    for row in adversarial_rows:
        print(f"  raw         : {row['raw']}")
        if "base" in row:
            print(f"  base        : {row['base']}")
        print(f"  fine_tuned  : {row['fine_tuned']}")
        print()


# run the eval only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()