"""Data ingestion and validation utilities for the fraud dataset."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import pandas as pd

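# Expected schema of the raw dataset. The row count and column names match the
# public ULB/Kaggle credit-card fraud dataset: "Time", the anonymised PCA
# components V1..V28, "Amount", and the binary "Class" label (1 = fraud).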
EXPECTED_ROW_COUNT = 284_807
EXPECTED_COLUMNS = ["Time", *[f"V{i}" for i in range(1, 29)], "Amount", "Class"]
EXPECTED_CLASS_VALUES = {0, 1}


def load_data(file_path: str | Path) -> pd.DataFrame:
    """Load CSV data from disk."""
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"Dataset not found: {path}")
    if path.suffix.lower() != ".csv":
        raise ValueError(f"Expected a CSV file, got: {path.suffix}")
    return pd.read_csv(path)


def get_data_statistics(df: pd.DataFrame) -> dict[str, Any]:
    """Return key dataset statistics used for validation and monitoring."""
    class_counts: dict[str, int] = {}
    fraud_ratio: float | None = None

    if "Class" in df.columns:
        raw_counts = df["Class"].value_counts(dropna=False).to_dict()
        class_counts = {str(k): int(v) for k, v in raw_counts.items()}
        if len(df) > 0:
            fraud_ratio = float((df["Class"] == 1).sum() / len(df))

    return {
        "row_count": int(df.shape[0]),
        "column_count": int(df.shape[1]),
        "missing_values_total": int(df.isna().sum().sum()),
        "duplicate_rows": int(df.duplicated().sum()),
        "class_counts": class_counts,
        "fraud_ratio": fraud_ratio,
    }


def validate_data(df: pd.DataFrame, expected_rows: int = EXPECTED_ROW_COUNT) -> dict[str, Any]:
    """Validate schema and data quality; return a structured report."""
    errors: list[str] = []
    warnings: list[str] = []

    actual_columns = list(df.columns)
    missing_columns = [col for col in EXPECTED_COLUMNS if col not in actual_columns]
    unexpected_columns = [col for col in actual_columns if col not in EXPECTED_COLUMNS]

    if missing_columns:
        errors.append(f"Missing required columns: {missing_columns}")
    if unexpected_columns:
        warnings.append(f"Unexpected columns present: {unexpected_columns}")

    stats = get_data_statistics(df)

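    # Row-count drift and missing values are reported as warnings only; schema
    # problems (missing columns, invalid Class values) count as hard errors.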
    if expected_rows and stats["row_count"] != expected_rows:
        warnings.append(
            f"Row count differs from expected {expected_rows}: got {stats['row_count']}"
        )

    if stats["missing_values_total"] > 0:
        warnings.append(f"Dataset contains {stats['missing_values_total']} missing values")

    if "Class" in df.columns:
        class_values = set(df["Class"].dropna().unique().tolist())
        invalid_class_values = sorted(class_values - EXPECTED_CLASS_VALUES)
        if invalid_class_values:
            errors.append(f"Class contains invalid values: {invalid_class_values}")
        if len(class_values) == 1:
            warnings.append("Class column has only one class present")
    else:
        errors.append("Class column not found")

    is_valid = len(errors) == 0
    return {"is_valid": is_valid, "errors": errors, "warnings": warnings, "statistics": stats}


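# The report written below has roughly this shape (illustrative values only,
# assuming the Kaggle dataset):
#   {"is_valid": true, "errors": [], "warnings": [],
#    "statistics": {"row_count": 284807, "column_count": 31,
#                   "class_counts": {"0": 284315, "1": 492},
#                   "fraud_ratio": 0.0017, ...}}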
def save_validation_report(report: dict[str, Any], output_path: str | Path) -> Path:
    """Write validation report to JSON."""
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(report, indent=2), encoding="utf-8")
    return output


def run_data_validation(
    file_path: str | Path = "data/raw/creditcard.csv",
    report_path: str | Path = "artifacts/data_validation.json",
) -> dict[str, Any]:
    """Load dataset, validate, persist report, and fail fast on schema errors."""
    df = load_data(file_path)
    report = validate_data(df)
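    # The report is persisted even when validation fails; the exception below
    # is raised only after the artifact has been written to disk.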
    save_validation_report(report, report_path)
    if not report["is_valid"]:
        raise ValueError(f"Data validation failed: {report['errors']}")
    return report


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Validate fraud dataset schema and quality.")
    parser.add_argument(
        "--data-path",
        default="data/raw/creditcard.csv",
        help="Path to the raw CSV dataset.",
    )
    parser.add_argument(
        "--report-path",
        default="artifacts/data_validation.json",
        help="Path to write the validation report JSON.",
    )
    return parser


def main() -> None:
    args = _build_parser().parse_args()
    report = run_data_validation(args.data_path, args.report_path)
    print("Data validation passed.")
    print(json.dumps(report["statistics"], indent=2))


if __name__ == "__main__":
    main()
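
# Example CLI usage (the module filename below is an assumption; adjust it to
# the actual file name in your repo):
#   python data_validation.py --data-path data/raw/creditcard.csv \
#       --report-path artifacts/data_validation.json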