File size: 2,184 Bytes
6a02b16 c6cece9 6a02b16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import json
from pathlib import Path
from data_io import load_clauses, load_pairs
def _missing_positive_ids(pairs, ids_by_lang):
    """Return the positive_id of every pair that is absent from its language's corpus.

    Args:
        pairs: Iterable of mappings carrying "lang" and "positive_id" keys.
        ids_by_lang: Mapping from language code to the set of known clause ids.

    Returns:
        A list of offending positive ids, in input order (duplicates kept).
    """
    missing = []
    for pair in pairs:
        known_ids = ids_by_lang.get(pair["lang"])
        # NOTE(review): pairs whose language has no corpus entry are skipped,
        # not reported. This matches the previous behaviour; confirm whether
        # unknown languages should fail validation instead.
        if known_ids is not None and pair["positive_id"] not in known_ids:
            missing.append(pair["positive_id"])
    return missing


def main():
    """Validate the train/test pair files against the clause corpus.

    Three checks are performed:
      * no positive_id appears in both the train and the test pairs;
      * every train pair's positive_id exists in the corpus of its language;
      * every test pair's positive_id exists in the corpus of its language.

    A JSON report with counts and up to 20 sample ids per check is always
    written to artifacts/reports/data_validation.json, even when validation
    fails, so the failure can be inspected afterwards.

    Raises:
        SystemExit: with the message "DATA_VALIDATION_FAILED" when any of the
            checks finds at least one offending id.
    """
    clauses_path = "data/clauses_constitution_ru_kz.jsonl"
    train_path = "data/legal_assistant_train.jsonl"
    test_path = "data/legal_assistant_test.jsonl"

    ru, kz = load_clauses(clauses_path)
    ids_by_lang = {
        "ru": {clause["id"] for clause in ru},
        "kz": {clause["id"] for clause in kz},
    }

    train = load_pairs(train_path)
    test = load_pairs(test_path)

    # Ids shared by both splits mean the test set leaks into training.
    train_ids = {pair["positive_id"] for pair in train}
    test_ids = {pair["positive_id"] for pair in test}
    leakage = sorted(train_ids & test_ids)

    missing_train = _missing_positive_ids(train, ids_by_lang)
    missing_test = _missing_positive_ids(test, ids_by_lang)

    report = {
        "clauses_ru": len(ru),
        "clauses_kz": len(kz),
        "train_pairs": len(train),
        "test_pairs": len(test),
        "leakage_ids_count": len(leakage),
        "missing_train_positive_ids_count": len(missing_train),
        "missing_test_positive_ids_count": len(missing_test),
        "leakage_sample": leakage[:20],
        "missing_train_sample": missing_train[:20],
        "missing_test_sample": missing_test[:20],
    }

    # Write the report before deciding pass/fail so it exists in both cases.
    out = Path("artifacts/reports/data_validation.json")
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")

    if leakage or missing_train or missing_test:
        raise SystemExit("DATA_VALIDATION_FAILED")
# Script entry point. main() raises SystemExit when validation fails, so the
# success marker below is printed only when every check passed.
if __name__ == "__main__":
    main()
    print("DATA_VALIDATION_PASSED")
|