File size: 9,915 Bytes
8a08300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
Data Ingestion Module

This module handles loading and validating credit card transaction data.
It uses Pydantic for schema validation to ensure data quality before processing.

Author: PayShield-ML Team
"""

from datetime import datetime
from pathlib import Path
from typing import Literal, Optional, Union

import pandas as pd
import pyarrow.parquet as pq
from pydantic import BaseModel, Field, field_validator, model_validator

from src.features.constants import category_names, job_names


class TransactionSchema(BaseModel):
    """
    Pydantic model for validating one raw transaction record.

    Rules enforced by the field constraints and validators below:
    - Transaction amount, card number and Unix timestamp must be positive
    - Coordinates must be in range (lat: [-90, 90], long: [-180, 180]) for
      both the customer and the merchant
    - Neither coordinate pair may look like a latitude/longitude swap
    - Category and job must be members of ``category_names`` / ``job_names``
    - ``trans_date_trans_time`` and ``dob`` must parse with the expected format

    Attributes:
        trans_date_trans_time: Transaction timestamp ('YYYY-MM-DD HH:MM:SS')
        cc_num: Credit card number (PII - handle with care)
        merchant: Merchant name
        category: Transaction category (e.g., 'grocery_pos', 'gas_transport')
        amt: Transaction amount in USD
        first: Customer first name
        last: Customer last name
        gender: Customer gender ('M' or 'F')
        street: Street address
        city: City name
        state: State code (exactly 2 characters)
        zip: ZIP code, stored as an integer
        lat: Customer latitude (-90 to 90)
        long: Customer longitude (-180 to 180)
        city_pop: City population (non-negative)
        job: Customer job title
        dob: Date of birth ('YYYY-MM-DD')
        trans_num: Unique transaction identifier
        unix_time: Unix timestamp
        merch_lat: Merchant latitude (-90 to 90)
        merch_long: Merchant longitude (-180 to 180)
        is_fraud: Fraud label (0 or 1)
    """

    # Transaction Details
    trans_date_trans_time: str = Field(
        ..., description="Transaction timestamp in format 'YYYY-MM-DD HH:MM:SS'"
    )
    cc_num: int = Field(..., description="Credit card number", gt=0)
    merchant: str = Field(..., min_length=1, description="Merchant name")
    category: str = Field(..., description="Transaction category")
    amt: float = Field(..., gt=0.0, description="Transaction amount (must be positive)")

    # Customer Information
    first: str = Field(..., min_length=1, description="First name")
    last: str = Field(..., min_length=1, description="Last name")
    gender: Literal["M", "F"] = Field(..., description="Gender")
    street: str = Field(..., description="Street address")
    city: str = Field(..., description="City")
    state: str = Field(..., min_length=2, max_length=2, description="State code")
    # NOTE(review): as an int, leading zeros are lost, and ge=1000 rejects any
    # ZIP whose integer value is below 1000 - confirm the data never has those.
    zip: int = Field(..., ge=1000, le=99999, description="ZIP code")
    lat: float = Field(..., ge=-90.0, le=90.0, description="Customer latitude")
    long: float = Field(..., ge=-180.0, le=180.0, description="Customer longitude")
    city_pop: int = Field(..., ge=0, description="City population")
    job: str = Field(..., description="Job title")
    dob: str = Field(..., description="Date of birth in format 'YYYY-MM-DD'")

    # Transaction Metadata
    trans_num: str = Field(..., description="Unique transaction ID (hex string)")
    unix_time: int = Field(..., gt=0, description="Unix timestamp")
    merch_lat: float = Field(..., ge=-90.0, le=90.0, description="Merchant latitude")
    merch_long: float = Field(..., ge=-180.0, le=180.0, description="Merchant longitude")
    is_fraud: Literal[0, 1] = Field(..., description="Fraud indicator")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Reject categories that are not in ``category_names``."""
        if v not in category_names:
            # Only the first five names are shown to keep the message short.
            raise ValueError(
                f"Invalid category '{v}'. Must be one of: {', '.join(category_names[:5])}..."
            )
        return v

    @field_validator("job")
    @classmethod
    def validate_job(cls, v: str) -> str:
        """Reject job titles that are not in ``job_names``."""
        if v not in job_names:
            raise ValueError(
                f"Invalid job '{v}'. Must be one of the {len(job_names)} known job titles"
            )
        return v

    @field_validator("trans_date_trans_time")
    @classmethod
    def validate_timestamp(cls, v: str) -> str:
        """Ensure the timestamp parses as 'YYYY-MM-DD HH:MM:SS'; the string is returned as-is."""
        try:
            datetime.strptime(v, "%Y-%m-%d %H:%M:%S")
        except ValueError as e:
            raise ValueError(
                f"Invalid timestamp format '{v}'. Expected 'YYYY-MM-DD HH:MM:SS'"
            ) from e
        return v

    @field_validator("dob")
    @classmethod
    def validate_dob(cls, v: str) -> str:
        """Ensure the date of birth parses as 'YYYY-MM-DD'; the string is returned as-is."""
        try:
            datetime.strptime(v, "%Y-%m-%d")
        except ValueError as e:
            raise ValueError(f"Invalid date of birth format '{v}'. Expected 'YYYY-MM-DD'") from e
        return v

    @model_validator(mode="after")
    def validate_distance_sanity(self) -> "TransactionSchema":
        """
        Sanity check: reject coordinate pairs that look like a lat/long swap.

        The heuristic (|latitude| > 50 while |longitude| < 50) is applied to
        both the customer pair and the merchant pair, so corruption in either
        one is caught.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If either coordinate pair matches the swap pattern.
        """
        # Heuristic from the original check, which assumes a US-based dataset:
        # a large |lat| together with a small |long| suggests a swap.
        swap_threshold = 50

        if abs(self.lat) > swap_threshold and abs(self.long) < swap_threshold:
            raise ValueError(
                f"Suspicious coordinates: lat={self.lat}, long={self.long}. "
                f"Check if latitude and longitude are swapped."
            )

        # Same heuristic for the merchant pair, which was previously unchecked.
        if abs(self.merch_lat) > swap_threshold and abs(self.merch_long) < swap_threshold:
            raise ValueError(
                f"Suspicious merchant coordinates: merch_lat={self.merch_lat}, "
                f"merch_long={self.merch_long}. "
                f"Check if latitude and longitude are swapped."
            )
        return self


class InferenceTransactionSchema(BaseModel):
    """
    Request schema for real-time fraud inference.

    Carries only the fields used for prediction; names and addresses are
    deliberately absent. This is the payload shape the API endpoint expects.

    Attributes:
        user_id: Internal user identifier (replaces cc_num for privacy)
        amt: Transaction amount (must be positive)
        lat: User's last known latitude (-90 to 90)
        long: User's last known longitude (-180 to 180)
        category: Transaction category (must be in ``category_names``)
        job: User's job from their profile (must be in ``job_names``)
        merch_lat: Merchant latitude (-90 to 90)
        merch_long: Merchant longitude (-180 to 180)
        unix_time: Transaction timestamp as a positive Unix epoch value
    """

    user_id: str = Field(..., min_length=1, description="User identifier")
    amt: float = Field(..., gt=0.0, description="Transaction amount")
    lat: float = Field(..., ge=-90.0, le=90.0, description="User latitude")
    long: float = Field(..., ge=-180.0, le=180.0, description="User longitude")
    category: str = Field(..., description="Transaction category")
    job: str = Field(..., description="User job title")
    merch_lat: float = Field(..., ge=-90.0, le=90.0, description="Merchant latitude")
    merch_long: float = Field(..., ge=-180.0, le=180.0, description="Merchant longitude")
    unix_time: int = Field(..., gt=0, description="Transaction timestamp")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Accept only categories listed in ``category_names``."""
        if v in category_names:
            return v
        # The full list is included so API clients can see every valid option.
        allowed = ", ".join(category_names)
        raise ValueError(f"Invalid category '{v}'. Must be one of: {allowed}")

    @field_validator("job")
    @classmethod
    def validate_job(cls, v: str) -> str:
        """Accept only job titles listed in ``job_names``."""
        if v in job_names:
            return v
        raise ValueError(f"Invalid job '{v}'. Not in approved job list")


def load_dataset(
    file_path: Union[str, Path], validate: bool = True, sample_n: Optional[int] = None
) -> pd.DataFrame:
    """
    Load the credit card fraud dataset from CSV or Parquet with optional validation.

    The file type is chosen from the extension, compared case-insensitively
    (``data.csv`` and ``DATA.CSV`` are both read as CSV). Sampling, when
    requested, happens before validation, so only the sampled rows are checked.
    Progress messages are printed to stdout when validation runs.

    Args:
        file_path: Path to a ``.csv`` or ``.parquet`` file.
        validate: If True, validate each row against TransactionSchema.
                 Set to False for faster loading.
        sample_n: If specified, return at most N randomly sampled rows. The
                 sample uses a fixed seed, so repeated calls pick the same rows.

    Returns:
        The loaded (and optionally sampled) DataFrame. Rows are only guaranteed
        to satisfy TransactionSchema when ``validate`` is True.

    Raises:
        FileNotFoundError: If the file doesn't exist.
        ValueError: If the extension is unsupported, or if validation fails
            for any row (the message lists up to the first 10 failures).

    Example:
        >>> # Load and validate training data
        >>> df = load_dataset("fraudTrain.csv", validate=True)
        >>>
        >>> # Fast load for inference (skip validation)
        >>> df = load_dataset("fraudTrain.parquet", validate=False)
        >>>
        >>> # Load sample for testing
        >>> df_sample = load_dataset("fraudTrain.csv", sample_n=1000)
    """
    file_path = Path(file_path)

    if not file_path.exists():
        raise FileNotFoundError(f"Dataset not found: {file_path}")

    # Compare the extension case-insensitively so e.g. "DATA.CSV" is accepted;
    # the error message still reports the suffix exactly as the caller gave it.
    suffix = file_path.suffix.lower()
    if suffix == ".csv":
        df = pd.read_csv(file_path)
    elif suffix == ".parquet":
        df = pd.read_parquet(file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_path.suffix}. Use .csv or .parquet")

    # Sample if requested (capped at the frame length; fixed seed for repeatability)
    if sample_n is not None:
        df = df.sample(n=min(sample_n, len(df)), random_state=42)

    # Validate if requested
    if validate:
        max_errors = 10  # Stop collecting after this many failures to avoid spam
        print(f"Validating {len(df):,} transactions...")
        errors = []

        for idx, row in df.iterrows():
            try:
                TransactionSchema(**row.to_dict())
            except Exception as e:
                # Deliberately broad: any failure to build the model for a row
                # (schema violation or malformed keys) is reported per row and
                # surfaced below as a single ValueError.
                errors.append(f"Row {idx}: {e}")
                if len(errors) >= max_errors:
                    errors.append("... (stopped after 10 errors)")
                    break

        if errors:
            error_msg = "\n".join(errors)
            raise ValueError(f"Validation failed:\n{error_msg}")

        print(f"✓ All {len(df):,} transactions validated successfully")

    return df


# Public API: the names exported by a star-import of this module.
__all__ = [
    "TransactionSchema",
    "InferenceTransactionSchema",
    "load_dataset",
]