File size: 2,184 Bytes
6a02b16
 
c6cece9
6a02b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
from pathlib import Path
from data_io import load_clauses, load_pairs

def _missing_positive_ids(pairs, corpus_ids_by_lang):
    """Return the positive_id of every pair that is absent from its language's corpus.

    Args:
        pairs: Iterable of dicts with at least ``"lang"`` and ``"positive_id"``
            keys, as returned by ``load_pairs``.
        corpus_ids_by_lang: Mapping from language code to the set of clause ids
            available in that language.

    Returns:
        List of offending ``positive_id`` values in input order (duplicates kept).
        Pairs whose ``lang`` is not a key of ``corpus_ids_by_lang`` are skipped,
        not reported.
    """
    return [
        pair["positive_id"]
        for pair in pairs
        if pair["lang"] in corpus_ids_by_lang
        and pair["positive_id"] not in corpus_ids_by_lang[pair["lang"]]
    ]


def main(
    *,
    clauses_path="data/clauses_constitution_ru_kz.jsonl",
    train_path="data/legal_assistant_train.jsonl",
    test_path="data/legal_assistant_test.jsonl",
    report_path="artifacts/reports/data_validation.json",
):
    """Validate the train/test pair files against the clause corpus.

    Two checks are performed:

    * leakage: a ``positive_id`` that occurs in both the train and test pairs;
    * missing ids: a pair whose ``positive_id`` is not among the clause ids of
      the corpus for its ``lang`` ("ru" or "kz").

    A JSON report with the counts and up to 20 sample ids per problem is
    written to ``report_path`` (parent directories are created) *before* the
    outcome is decided, so the report exists even when validation fails.

    Args:
        clauses_path: Clause corpus file passed to ``load_clauses``.
        train_path: Training pairs file passed to ``load_pairs``.
        test_path: Test pairs file passed to ``load_pairs``.
        report_path: Destination of the JSON validation report.

    Raises:
        SystemExit: with message "DATA_VALIDATION_FAILED" if any check fails.
    """
    ru, kz = load_clauses(clauses_path)
    corpus_ids = {
        "ru": {clause["id"] for clause in ru},
        "kz": {clause["id"] for clause in kz},
    }

    # Both splits are iterated more than once and passed to len(), so they
    # must be sequences (not one-shot iterators).
    train = load_pairs(train_path)
    test = load_pairs(test_path)

    # A positive_id shared by both splits is reported as leakage.
    train_ids = {pair["positive_id"] for pair in train}
    test_ids = {pair["positive_id"] for pair in test}
    leakage = sorted(train_ids & test_ids)

    # NOTE(review): pairs whose lang is neither "ru" nor "kz" are not checked
    # at all (unchanged behavior) — confirm no other language codes occur.
    missing_train = _missing_positive_ids(train, corpus_ids)
    missing_test = _missing_positive_ids(test, corpus_ids)

    report = {
        "clauses_ru": len(ru),
        "clauses_kz": len(kz),
        "train_pairs": len(train),
        "test_pairs": len(test),
        "leakage_ids_count": len(leakage),
        "missing_train_positive_ids_count": len(missing_train),
        "missing_test_positive_ids_count": len(missing_test),
        "leakage_sample": leakage[:20],
        "missing_train_sample": missing_train[:20],
        "missing_test_sample": missing_test[:20],
    }

    out = Path(report_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")

    if leakage or missing_train or missing_test:
        raise SystemExit("DATA_VALIDATION_FAILED")

if __name__ == "__main__":
    # main() raises SystemExit when any validation check fails, so control
    # only reaches the print below after every check has passed.
    main()
    print("DATA_VALIDATION_PASSED")