Spaces:
Runtime error
Runtime error
| """Three-tier validation suite for DIME parquet conversions.""" | |
| from __future__ import annotations | |
| import csv | |
| import gzip | |
| import random | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| import pyarrow.compute as pc | |
| import pyarrow.parquet as pq | |
| from .exceptions import ( | |
| ChecksumMismatchError, | |
| RowCountMismatchError, | |
| SampleMismatchError, | |
| ) | |
| if TYPE_CHECKING: | |
| from .converter import StreamingStats | |
@dataclass
class ValidationResult:
    """Results from validation suite.

    Populated incrementally by the three validation tiers; each tier sets its
    own ``*_valid`` flag to True only after all of its checks pass (a failing
    check raises instead).
    """

    # Tier 1: row counts
    row_count_valid: bool = False
    row_count_expected: int = 0
    row_count_actual: int = 0
    # Tier 2: checksums
    checksum_valid: bool = False
    sum_column_name: str | None = None  # Name of column used for sum validation
    sum_column_expected: float = 0.0
    sum_column_actual: float = 0.0
    # Column name -> (source non-null count, parquet non-null count).
    # default_factory gives each instance its own dict.
    non_null_counts: dict[str, tuple[int, int]] = field(default_factory=dict)
    # Tier 3: sampled rows
    sample_valid: bool = False
    sample_size: int = 0

    def all_valid(self) -> bool:
        """Return True only if all three validation tiers passed."""
        return self.row_count_valid and self.checksum_valid and self.sample_valid
def validate_row_count(
    source_path: Path,
    output_path: Path,
    expected_count: int,
) -> ValidationResult:
    """
    Tier 1 validation: check that the parquet output has the expected row count.

    This is the cheapest tier: only the parquet file metadata is read, never
    the row data itself.

    Args:
        source_path: Path to the source CSV (used only for error reporting).
        output_path: Path to the parquet file being validated.
        expected_count: Number of rows the conversion should have produced.

    Returns:
        A new ValidationResult with the row-count fields filled in and
        ``row_count_valid`` set to True.

    Raises:
        RowCountMismatchError: If the parquet row count differs from
            ``expected_count``.
    """
    observed_rows = pq.read_metadata(output_path).num_rows

    outcome = ValidationResult(
        row_count_expected=expected_count,
        row_count_actual=observed_rows,
    )

    counts_agree = observed_rows == expected_count
    if not counts_agree:
        raise RowCountMismatchError(
            source_path=source_path,
            message="Row count validation failed",
            expected_rows=expected_count,
            actual_rows=observed_rows,
        )

    outcome.row_count_valid = True
    return outcome
def validate_checksums(
    source_path: Path,
    output_path: Path,
    source_stats: StreamingStats,
    result: ValidationResult,
    sum_column: str | None = "amount",
    key_columns: list[str] | None = None,
    sum_abs_tolerance: float = 0.01,
    sum_rel_tolerance: float = 0.0,
) -> ValidationResult:
    """
    Tier 2 validation: Verify column checksums.

    Compares streaming stats (accumulated during conversion) against
    parquet data using memory-efficient column-level reads.

    Compares:
    - Sum of configurable column (detects truncation/conversion errors)
    - Non-null counts for key columns (detects dropped data)

    Args:
        source_path: Path to source CSV file (used for error reporting)
        output_path: Path to output parquet file
        source_stats: Statistics accumulated during conversion; reads
            ``sum_column_value`` and ``non_null_counts``
        result: ValidationResult to update in place
        sum_column: Column to sum for checksum (None to skip sum validation)
        key_columns: Columns to check non-null counts (defaults to
            transaction.id, bonica.cid, contributor.name, amount)
        sum_abs_tolerance: Maximum absolute difference allowed between the
            source and parquet sums (default 0.01, the historical threshold).
        sum_rel_tolerance: Extra relative tolerance as a fraction of the larger
            sum's magnitude; 0.0 (default) disables it. Useful for very large
            totals where float summation order alone can exceed a fixed
            absolute threshold.

    Returns:
        ``result``, updated with the checksum details and ``checksum_valid``
        set to True.

    Raises:
        ChecksumMismatchError: If the sums differ beyond tolerance (including
            when exactly one of them is NaN), or if any key column's non-null
            count differs between source and parquet.
    """
    import math

    if key_columns is None:
        key_columns = ["transaction.id", "bonica.cid", "contributor.name", "amount"]
    result.sum_column_name = sum_column

    # Checksum 1: Sum of configurable column
    if sum_column:
        source_sum = source_stats.sum_column_value
        # Read only the sum column from parquet (memory-efficient)
        sum_table = pq.read_table(output_path, columns=[sum_column])
        # pc.sum yields null for an all-null column; treat that as 0.0.
        parquet_sum = pc.sum(sum_table.column(sum_column)).as_py() or 0.0
        result.sum_column_expected = source_sum
        result.sum_column_actual = parquet_sum

        source_is_nan = math.isnan(source_sum)
        parquet_is_nan = math.isnan(parquet_sum)
        if source_is_nan or parquet_is_nan:
            # `abs(a - b) > tol` is False for NaN, so a NaN on one side used to
            # pass silently. Accept NaN only when both sides agree on it.
            sums_match = source_is_nan and parquet_is_nan
        else:
            # With rel_tol=0 this is exactly `abs(a - b) <= sum_abs_tolerance`;
            # equal infinities also count as a match.
            sums_match = math.isclose(
                source_sum,
                parquet_sum,
                rel_tol=sum_rel_tolerance,
                abs_tol=sum_abs_tolerance,
            )
        if not sums_match:
            raise ChecksumMismatchError(
                source_path=source_path,
                message=f"{sum_column} sum mismatch",
                column_name=sum_column,
                expected_value=source_sum,
                actual_value=parquet_sum,
            )

    # Checksum 2: Non-null counts for key columns (column-level reads)
    for col in key_columns:
        source_count = source_stats.non_null_counts.get(col, 0)
        # Read only this column from parquet
        col_table = pq.read_table(output_path, columns=[col])
        parquet_count = pc.count(col_table.column(col)).as_py()
        result.non_null_counts[col] = (source_count, parquet_count)
        if source_count != parquet_count:
            raise ChecksumMismatchError(
                source_path=source_path,
                message=f"Non-null count mismatch for {col}",
                column_name=col,
                expected_value=source_count,
                actual_value=parquet_count,
            )

    result.checksum_valid = True
    return result
def validate_sample_rows(
    source_path: Path,
    output_path: Path,
    sample_size: int,
    result: ValidationResult,
    seed: int | None = None,
) -> ValidationResult:
    """
    Tier 3 validation: Compare random sample of rows.

    Uses memory-efficient batch reading: iterates through parquet in batches,
    capturing sample rows as encountered rather than loading entire file.
    Values are normalized with ``_normalize_value`` and compared with
    ``_values_equal`` for every column in the parquet schema.

    Args:
        source_path: Path to the gzip-compressed source CSV.
        output_path: Path to the output parquet file.
        sample_size: Number of rows to compare (capped at the parquet row count).
        result: ValidationResult to update in place.
        seed: Optional seed for reproducible sampling. When None (default),
            the module-level ``random`` state is used, as before.

    Returns:
        ``result`` with ``sample_size`` set and ``sample_valid`` True.

    Raises:
        SampleMismatchError: If a sampled row is missing from the source CSV,
            or any column value differs after normalization.
    """
    # Get total row count from parquet metadata
    total_rows = pq.read_metadata(output_path).num_rows

    # Select random row indices (sorted so both files can be scanned in order)
    rng = random if seed is None else random.Random(seed)
    actual_sample_size = min(sample_size, total_rows)
    sample_indices = sorted(rng.sample(range(total_rows), actual_sample_size))
    n_samples = len(sample_indices)

    # Read source CSV rows at sample indices
    source_rows = _read_csv_rows_at_indices(source_path, sample_indices)

    # Read parquet in batches, capturing sample rows (memory-efficient)
    parquet_file = pq.ParquetFile(output_path)
    schema_names = parquet_file.schema_arrow.names
    parquet_sample_rows: list[dict | None] = [None] * n_samples

    # sample_indices is sorted and batches arrive in row order, so one cursor
    # walks both in lockstep: O(batches + samples) rather than re-scanning every
    # sample index for every batch.
    next_pos = 0
    batch_start = 0
    batches = parquet_file.iter_batches() if n_samples else ()
    for batch in batches:
        batch_end = batch_start + batch.num_rows
        while next_pos < n_samples and sample_indices[next_pos] < batch_end:
            local_idx = sample_indices[next_pos] - batch_start
            parquet_sample_rows[next_pos] = {
                col_name: batch.column(col_name)[local_idx].as_py()
                for col_name in schema_names
            }
            next_pos += 1
        batch_start = batch_end
        # Early exit once every sample has been captured
        if next_pos == n_samples:
            break

    # Compare
    result.sample_size = n_samples
    for row_idx, source_row, parquet_row in zip(
        sample_indices, source_rows, parquet_sample_rows
    ):
        if source_row is None:
            # The CSV ran out of rows before this index; report it as a sample
            # mismatch instead of crashing with AttributeError on None.get().
            raise SampleMismatchError(
                source_path=source_path,
                message="Sample row missing from source CSV",
                row_index=row_idx,
                column_name="<entire row>",
                expected_value="<row missing in source>",
                actual_value="<row present in parquet>",
            )
        for col_name in schema_names:
            source_val = source_row.get(col_name)
            parquet_val = parquet_row.get(col_name) if parquet_row else None
            # Normalize for comparison
            source_normalized = _normalize_value(source_val)
            parquet_normalized = _normalize_value(parquet_val)
            if not _values_equal(source_normalized, parquet_normalized):
                raise SampleMismatchError(
                    source_path=source_path,
                    message="Sample row mismatch",
                    row_index=row_idx,
                    column_name=col_name,
                    expected_value=str(source_normalized),
                    actual_value=str(parquet_normalized),
                )

    result.sample_valid = True
    return result
| def _read_csv_rows_at_indices(path: Path, indices: list[int]) -> list[dict]: | |
| """Read specific rows from CSV by index.""" | |
| rows = [] | |
| indices_set = set(indices) | |
| index_to_position = {idx: pos for pos, idx in enumerate(sorted(indices))} | |
| # Pre-allocate result list | |
| rows = [None] * len(indices) | |
| with gzip.open(path, "rt", encoding="latin1") as f: | |
| reader = csv.DictReader(f, doublequote=True) | |
| for i, row in enumerate(reader): | |
| if i in indices_set: | |
| rows[index_to_position[i]] = row | |
| if all(r is not None for r in rows): | |
| break | |
| return rows | |
| def _normalize_value(val): | |
| """Normalize values for comparison.""" | |
| import math | |
| if val is None or val == "" or val == "\\N": | |
| return None | |
| if isinstance(val, float): | |
| # Treat NaN as None for comparison purposes | |
| if math.isnan(val): | |
| return None | |
| # Round to avoid floating point comparison issues | |
| return round(val, 6) | |
| if isinstance(val, str): | |
| stripped = val.strip() | |
| # Treat string 'nan' as None (case-insensitive) | |
| if stripped.lower() == "nan": | |
| return None | |
| return stripped | |
| return val | |
| def _values_equal(a, b) -> bool: | |
| """Compare two normalized values.""" | |
| if a is None and b is None: | |
| return True | |
| if a is None or b is None: | |
| return False | |
| # Try numeric comparison if either value looks numeric | |
| try: | |
| a_float = float(a) if isinstance(a, str) else a | |
| b_float = float(b) if isinstance(b, str) else b | |
| if isinstance(a_float, (int, float)) and isinstance(b_float, (int, float)): | |
| return abs(float(a_float) - float(b_float)) < 0.000001 | |
| except (ValueError, TypeError): | |
| pass | |
| # String comparison | |
| return str(a) == str(b) | |