# Scrape metadata (not part of the program): file size 3,710 bytes, id aa988a7 (likely a commit hash).
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch

# This script's directory and its parent. Both are prepended to sys.path so the
# local modules imported below (marked `noqa: E402`) resolve when the file is run
# directly; those modules are assumed to live in one of these two directories.
_METRICS_DIR = Path(__file__).resolve().parent
_SCRIPTS_DIR = _METRICS_DIR.parent
for path in (_METRICS_DIR, _SCRIPTS_DIR):
    if str(path) in sys.path:
        continue  # already importable; avoid duplicate entries
    # Each insert goes to index 0, so the parent ends up ahead of this directory.
    sys.path.insert(0, str(path))

from broken_code_generation import (  # noqa: E402
    ADAPTER_DIR,
    DEFAULT_EVAL_LIMIT,
    EVAL_FILE,
    FILE_JSON_VALIDITY,
    GEN_MAX_NEW_TOKENS,
    GEN_SEED,
    GEN_TEMPERATURE,
    GEN_TOP_P,
    MODEL_ID,
)
from evaluate_model import REQUIRED_FIELDS, generate_one, load_model_and_tokenizer  # noqa: E402
from report_io import metrics_path, write_report  # noqa: E402


def _non_negative_int(text: str) -> int:
    """Parse a ``--limit`` value, rejecting negative numbers.

    The evaluation records are sliced with ``[: limit]``; a negative bound
    would silently drop records from the *end* of the file instead of capping
    how many are evaluated, so it is rejected at parse time.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 0:
        raise argparse.ArgumentTypeError(f"must be >= 0, got {value}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the JSON-validity evaluation.

    Options:
        --limit: maximum number of evaluation records to score (default
            ``DEFAULT_EVAL_LIMIT``); must be a non-negative integer.
        --output: path for the report; when omitted (``None``) the caller
            falls back to the default metrics path.
    """
    parser = argparse.ArgumentParser(
        description=f"JSON validity for {MODEL_ID} only (adapter at {ADAPTER_DIR})."
    )
    parser.add_argument("--limit", type=_non_negative_int, default=DEFAULT_EVAL_LIMIT)
    parser.add_argument("--output", type=Path, default=None)
    return parser.parse_args()


def _score_generation(generated: object, record: dict) -> tuple[bool, bool, bool]:
    """Return ``(has_required, difficulty_match, tags_match)`` for one sample.

    A generation that is not a ``dict`` fails every check. Topic tags are
    compared as sets of elements (a dict contributes its keys); a missing or
    ``None`` ``topic_tags`` value counts as empty, and a non-iterable or
    unhashable one counts as a mismatch instead of aborting the row.
    """
    if not isinstance(generated, dict):
        return False, False, False
    has_required = REQUIRED_FIELDS.issubset(generated)
    difficulty_match = generated.get("difficulty") == record["difficulty"]
    try:
        tags_match = set(generated.get("topic_tags") or ()) == set(record["topic_tags"])
    except TypeError:  # non-iterable / unhashable tags must not abort scoring
        tags_match = False
    return has_required, difficulty_match, tags_match


def main() -> None:
    """Evaluate generations from the adapter model and write a JSON-validity report.

    Seeds torch with ``GEN_SEED``, loads up to ``--limit`` records from
    ``EVAL_FILE`` and generates one sample per record using the configured
    sampling settings. Four counts are kept:
    - ``valid_json``: records where ``generate_one`` returned without raising
      (presumably it raises when the output does not parse as JSON; verify in
      ``evaluate_model``).
    - ``required_fields_complete``: generations containing every field in
      ``REQUIRED_FIELDS``.
    - ``difficulty_match``: generations that echo the requested difficulty.
    - ``topic_tag_keys_match``: generations that reproduce the requested
      topic tags as a set.

    Rates are the counts divided by ``max(len(records), 1)``. The report goes
    to ``--output`` or, if omitted, to ``metrics_path(FILE_JSON_VALIDITY)``.

    Raises:
        FileNotFoundError: if ``ADAPTER_DIR`` does not exist.
    """
    args = parse_args()
    torch.manual_seed(GEN_SEED)

    if not ADAPTER_DIR.exists():
        raise FileNotFoundError(f"Adapter not found: {ADAPTER_DIR}")

    records = json.loads(EVAL_FILE.read_text(encoding="utf-8"))[: args.limit]
    print(f"Model: {MODEL_ID}")
    print(f"Adapter: {ADAPTER_DIR}")
    print(f"Samples: {len(records)} from {EVAL_FILE}")

    model, tokenizer = load_model_and_tokenizer(ADAPTER_DIR)
    model.eval()

    valid_json = required = difficulty_ok = tags_ok = 0
    results = []

    for index, record in enumerate(records, start=1):
        row = {"index": index, "status": "error"}
        # Only generation is guarded: a failure there is a per-sample error,
        # while scoring a successful generation happens in the `else` branch
        # so a row is never both "ok" and carrying an "error".
        try:
            generated = generate_one(
                model=model,
                tokenizer=tokenizer,
                topic_tags=record["topic_tags"],
                difficulty=record["difficulty"],
                max_new_tokens=GEN_MAX_NEW_TOKENS,
                temperature=GEN_TEMPERATURE,
                top_p=GEN_TOP_P,
            )
        except Exception as error:  # noqa: BLE001 - one bad sample must not abort the run
            row["error"] = str(error)
        else:
            valid_json += 1
            row["status"] = "ok"
            row["generated"] = generated
            has_required, difficulty_match, tags_match = _score_generation(generated, record)
            required += int(has_required)
            difficulty_ok += int(difficulty_match)
            tags_ok += int(tags_match)
        results.append(row)
        print(f"[{MODEL_ID}] {index}/{len(records)} valid_json={valid_json}", flush=True)

    n = max(len(records), 1)  # avoid division by zero when no records are evaluated
    report = {
        "metric_group": "json_validity",
        "model": MODEL_ID,
        "adapter_dir": str(ADAPTER_DIR),
        "evaluation_file": str(EVAL_FILE),
        "samples_evaluated": len(records),
        "generation": {
            "temperature": GEN_TEMPERATURE,
            "top_p": GEN_TOP_P,
            "max_new_tokens": GEN_MAX_NEW_TOKENS,
            "seed": GEN_SEED,
        },
        "metrics": {
            "valid_json_rate": round(valid_json / n, 4),
            "required_fields_rate": round(required / n, 4),
            "difficulty_match_rate": round(difficulty_ok / n, 4),
            "topic_tag_key_match_rate": round(tags_ok / n, 4),
        },
        "metrics_counts": {
            "valid_json": valid_json,
            "required_fields_complete": required,
            "difficulty_match": difficulty_ok,
            "topic_tag_keys_match": tags_ok,
        },
        "results": results,
    }
    write_report(args.output or metrics_path(FILE_JSON_VALIDITY), report)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()