| import json | |
| from pathlib import Path | |
| from data_io import load_clauses, load_pairs | |
def _missing_positive_ids(pairs: list, corpus_ids_by_lang: dict) -> list:
    """Return positive IDs from *pairs* that are absent from their language's corpus.

    Args:
        pairs: Records that each carry a ``"lang"`` and a ``"positive_id"`` key.
        corpus_ids_by_lang: Maps a language code to the set of clause IDs
            known for that language.

    Returns:
        The offending ``positive_id`` values in input order (duplicates kept,
        one entry per offending record).
    """
    missing = []
    for pair in pairs:
        known_ids = corpus_ids_by_lang.get(pair["lang"])
        # Records whose language has no corpus entry are skipped, not flagged.
        if known_ids is not None and pair["positive_id"] not in known_ids:
            missing.append(pair["positive_id"])
    return missing


def main(
    clauses_path: str = "data/clauses_constitution_ru_kz.jsonl",
    train_path: str = "data/legal_assistant_train.jsonl",
    test_path: str = "data/legal_assistant_test.jsonl",
    report_path: str = "artifacts/reports/data_validation.json",
) -> None:
    """Validate the train/test pair files against the clause corpus.

    Two checks are performed:

    * leakage -- a ``positive_id`` that occurs in both the train and the
      test pairs (compared across all languages);
    * missing IDs -- a ``"ru"`` or ``"kz"`` pair whose ``positive_id`` is not
      among the clause IDs loaded for that language.

    A JSON report with the counts and up to 20 sample IDs per check is always
    written to *report_path* (parent directories are created as needed)
    before the outcome is signalled.

    Args:
        clauses_path: File passed to ``load_clauses``.
        train_path: File passed to ``load_pairs`` for the training pairs.
        test_path: File passed to ``load_pairs`` for the test pairs.
        report_path: Destination of the JSON validation report.

    Raises:
        SystemExit: With the message ``DATA_VALIDATION_FAILED`` when any
            leakage or missing ID was found.
    """
    ru, kz = load_clauses(clauses_path)
    corpus_ids_by_lang = {
        "ru": {clause["id"] for clause in ru},
        "kz": {clause["id"] for clause in kz},
    }

    train = load_pairs(train_path)
    test = load_pairs(test_path)

    train_ids = {pair["positive_id"] for pair in train}
    test_ids = {pair["positive_id"] for pair in test}
    leakage = sorted(train_ids & test_ids)

    missing_train = _missing_positive_ids(train, corpus_ids_by_lang)
    missing_test = _missing_positive_ids(test, corpus_ids_by_lang)

    report = {
        "clauses_ru": len(ru),
        "clauses_kz": len(kz),
        "train_pairs": len(train),
        "test_pairs": len(test),
        "leakage_ids_count": len(leakage),
        "missing_train_positive_ids_count": len(missing_train),
        "missing_test_positive_ids_count": len(missing_test),
        "leakage_sample": leakage[:20],
        "missing_train_sample": missing_train[:20],
        "missing_test_sample": missing_test[:20],
    }

    # Persist the report first so it is available even when validation fails.
    out = Path(report_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")

    if leakage or missing_train or missing_test:
        raise SystemExit("DATA_VALIDATION_FAILED")
if __name__ == "__main__":
    # main() raises SystemExit("DATA_VALIDATION_FAILED") when a check fails,
    # so the success marker below is only reached after a clean validation run.
    main()
    print("DATA_VALIDATION_PASSED")