lexir / src /validate.py
irinaqqq's picture
ADDED MORE GRAPHS
c6cece9
import json
from pathlib import Path
from data_io import load_clauses, load_pairs
def _missing_positive_ids(pairs, corpus_ids_by_lang):
    """Return the ``positive_id`` of every pair whose target is absent from its corpus.

    Each pair is read through ``pair["lang"]`` and ``pair["positive_id"]``.
    The result follows the order of *pairs*. Duplicates are kept so that the
    count reflects how many pairs are broken, not how many distinct ids are.

    Pairs whose ``lang`` has no entry in *corpus_ids_by_lang* are skipped
    without being reported.
    """
    missing = []
    for pair in pairs:
        corpus_ids = corpus_ids_by_lang.get(pair["lang"])
        if corpus_ids is None:
            # NOTE(review): pairs in any language other than the mapped ones
            # are never checked -- confirm no other lang codes appear in the data.
            continue
        if pair["positive_id"] not in corpus_ids:
            missing.append(pair["positive_id"])
    return missing


def main(
    clauses_path="data/clauses_constitution_ru_kz.jsonl",
    train_path="data/legal_assistant_train.jsonl",
    test_path="data/legal_assistant_test.jsonl",
    report_path="artifacts/reports/data_validation.json",
):
    """Validate the clause corpus and the train/test pairs, writing a JSON report.

    Two checks are performed:

    * leakage -- no ``positive_id`` may appear in both the train and test pairs;
    * referential integrity -- every pair with ``lang`` ``"ru"`` or ``"kz"``
      must point at a clause ``id`` present in that language's corpus.

    The report holds clause/pair counts, problem counts and up to 20 sample
    ids per problem. It is written as UTF-8 JSON to *report_path*, creating
    parent directories as needed, before any failure is signalled.

    Args:
        clauses_path: clause corpus passed to ``load_clauses``.
        train_path: training pairs passed to ``load_pairs``.
        test_path: test pairs passed to ``load_pairs``.
        report_path: destination of the JSON report.

    Raises:
        SystemExit: with message ``"DATA_VALIDATION_FAILED"`` when any leakage
            or missing positive id is found.
    """
    ru, kz = load_clauses(clauses_path)
    corpus_ids_by_lang = {
        "ru": {clause["id"] for clause in ru},
        "kz": {clause["id"] for clause in kz},
    }

    train = load_pairs(train_path)
    test = load_pairs(test_path)

    # Leakage: the same target clause is used as a positive in both splits.
    train_ids = {pair["positive_id"] for pair in train}
    test_ids = {pair["positive_id"] for pair in test}
    leakage = sorted(train_ids & test_ids)

    missing_train = _missing_positive_ids(train, corpus_ids_by_lang)
    missing_test = _missing_positive_ids(test, corpus_ids_by_lang)

    report = {
        "clauses_ru": len(ru),
        "clauses_kz": len(kz),
        "train_pairs": len(train),
        "test_pairs": len(test),
        "leakage_ids_count": len(leakage),
        "missing_train_positive_ids_count": len(missing_train),
        "missing_test_positive_ids_count": len(missing_test),
        "leakage_sample": leakage[:20],
        "missing_train_sample": missing_train[:20],
        "missing_test_sample": missing_test[:20],
    }

    # Write the report first so it is available for inspection even on failure.
    out = Path(report_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")

    if leakage or missing_train or missing_test:
        raise SystemExit("DATA_VALIDATION_FAILED")
if __name__ == "__main__":
    # main() raises SystemExit("DATA_VALIDATION_FAILED") when a check fails,
    # so the success line below is printed only if every check passed.
    main()
    print("DATA_VALIDATION_PASSED")