Spaces:
Running
Running
| from typing import Optional, Literal, ClassVar | |
| from pydantic import BaseModel, Field, ConfigDict | |
| from datetime import datetime | |
| from src.entity.transaction import Transaction | |
| import logging | |
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
class TransactionApi(BaseModel):
    """
    Request model describing a single transaction submitted to the API.

    Pydantic performs field-level validation on construction. On top of that,
    :meth:`is_valid` applies business-rule checks, and :meth:`to_transaction`
    maps the payload onto the internal ``Transaction`` entity.
    """
    transaction_number: str = Field(..., title="Transaction number", description="The ID of the transaction.")
    # Interpreted as a Unix timestamp in MILLISECONDS: both is_valid() and
    # to_transaction() divide it by 1000 before converting it to a datetime.
    transaction_timestamp: int = Field(..., title="Timestamp", description="The Unix timestamp of the transaction, in milliseconds.")
    transaction_amount: float = Field(..., ge=0, title="Amount", description="The transaction amount.")
    transaction_category: Optional[str] = Field(None, title="Product category", description="The category of product of the transaction.")
    transaction_is_real_fraud: Optional[bool] = Field(None, title="Ground truth", description="True if the transaction really is a fraud (ground truth).")
    merchant_name: str = Field(..., title="Name", description="The name of the merchant.")
    merchant_latitude: float = Field(..., title="Latitude", description="The latitude of the merchant.")
    merchant_longitude: float = Field(..., title="Longitude", description="The longitude of the merchant.")
    customer_credit_card_number: int | str = Field(..., title="Customer's credit card", description="The credit card number of the customer.")
    customer_gender: Optional[Literal['M', 'F']] = Field(None, title="Customer's gender", description="The gender of the customer.")
    customer_first_name: Optional[str] = Field(None, title="Customer's first name", description="The first name of the customer.")
    customer_last_name: Optional[str] = Field(None, title="Customer's last name", description="The last name of the customer.")
    # Expected as "YYYY-MM-DD"; the format is only checked in is_valid().
    customer_date_of_birth: Optional[str] = Field(None, title="Customer's date of birth", description="The date of birth of the customer.")
    customer_job: Optional[str] = Field(None, title="Customer's job", description="The job of the customer.")
    customer_street: Optional[str] = Field(None, title="Customer's street", description="The street of the customer.")
    customer_city: Optional[str] = Field(None, title="Customer's city", description="The city of the customer.")
    customer_state: Optional[str] = Field(None, title="Customer's state", description="The state of the customer.")
    customer_postal_code: Optional[int] = Field(None, title="Customer's postal code", description="The postal code of the customer.")
    customer_latitude: float = Field(..., title="Customer's latitude", description="The latitude of the customer.")
    customer_longitude: float = Field(..., title="Customer's longitude", description="The longitude of the customer.")
    customer_city_population: Optional[int] = Field(None, ge=0, title="Customer's city population", description="The population of the city.")

    def is_valid(self) -> bool:
        """
        Apply business-rule checks that field validation alone does not cover.

        Checks, in order:

        * the transaction amount is not negative;
        * the transaction timestamp (milliseconds since the Unix epoch) can be
          converted to a datetime and is not in the future;
        * if a date of birth is given, it is in ``YYYY-MM-DD`` format and the
          customer is at least 10 years old.

        The first failing check is logged at ERROR level.

        Returns:
            True if every check passes, False otherwise.
        """
        logger.debug("Validating transaction...")
        # Naive local "now", matching the naive datetimes produced by
        # datetime.fromtimestamp() and datetime.strptime() below.
        now = datetime.now().replace(microsecond=0)

        # The amount must not be negative (zero is allowed). The Field already
        # declares ge=0; this re-check covers instances that skipped validation
        # (e.g. model_construct() or plain attribute assignment).
        if self.transaction_amount < 0:
            logger.error("Transaction amount is negative.")
            return False

        # The timestamp is in milliseconds; integer-divide by 1000 to get whole
        # seconds. Out-of-range values raise ValueError, OverflowError or
        # OSError depending on magnitude and platform. All three are reported
        # as an invalid timestamp instead of escaping as unhandled errors.
        try:
            dt_object = datetime.fromtimestamp(self.transaction_timestamp // 1000)
        except (ValueError, OverflowError, OSError):
            logger.error("Invalid transaction timestamp.")
            return False
        if dt_object > now:
            logger.error(f"Transaction timestamp is in the future ({dt_object} > {now}).")
            return False

        # The customer should be at least 10 years old to possess a credit card.
        # If no date of birth is provided, the customer is assumed old enough.
        if self.customer_date_of_birth:
            try:
                dob = datetime.strptime(self.customer_date_of_birth, "%Y-%m-%d")
            except ValueError:
                logger.error("Invalid customer date of birth format. Expected YYYY-MM-DD.")
                return False
            # Exact calendar age: subtract one year if this year's birthday has
            # not happened yet. (days // 365 over-counts because of leap days.)
            age = now.year - dob.year - ((now.month, now.day) < (dob.month, dob.day))
            if age < 10:
                logger.error("Customer is younger than 10 years old.")
                return False

        return True

    def to_transaction(self) -> Transaction:
        """
        Map this API payload onto the internal ``Transaction`` entity.

        The millisecond timestamp is converted with true division, so
        sub-second precision is kept in ``transaction_datetime``. Field names
        are translated to the entity's naming (e.g. ``customer_postal_code`` ->
        ``customer_address_zip``).

        Call :meth:`is_valid` first: ``datetime.fromtimestamp`` raises here if
        the timestamp is out of range.
        """
        return Transaction(
            transaction_number=self.transaction_number,
            transaction_amount=self.transaction_amount,
            transaction_datetime=datetime.fromtimestamp(self.transaction_timestamp / 1000),
            transaction_category=self.transaction_category,
            merchant_name=self.merchant_name,
            merchant_address_latitude=self.merchant_latitude,
            merchant_address_longitude=self.merchant_longitude,
            customer_credit_card_number=self.customer_credit_card_number,
            customer_gender=self.customer_gender,
            customer_firstname=self.customer_first_name,
            customer_lastname=self.customer_last_name,
            customer_dob=self.customer_date_of_birth,
            customer_job=self.customer_job,
            customer_address_street=self.customer_street,
            customer_address_city=self.customer_city,
            customer_address_state=self.customer_state,
            customer_address_zip=self.customer_postal_code,
            customer_address_latitude=self.customer_latitude,
            customer_address_longitude=self.customer_longitude,
            customer_address_city_population=self.customer_city_population,
            is_real_fraud=self.transaction_is_real_fraud,
        )

    # Pydantic v2 configuration, consistent with the other models in this
    # module. Replaces the deprecated v1-style inner ``class Config``.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "transaction_number": "123456789",
                # Milliseconds since the Unix epoch, matching how the field is
                # interpreted by is_valid() and to_transaction().
                "transaction_timestamp": 1633036800000,
                "transaction_amount": 100.0,
                "transaction_category": "Electronics",
                "merchant_name": "Best Buy",
                "merchant_latitude": 37.7749,
                "merchant_longitude": -122.4194,
                "customer_credit_card_number": "4111111111111111",
                "customer_gender": "M",
                "customer_first_name": "John",
                "customer_last_name": "Doe",
                "customer_date_of_birth": "1980-01-01",
                "customer_job": "Engineer",
                "customer_street": "123 Main St",
                "customer_city": "San Francisco",
                "customer_state": "CA",
                "customer_postal_code": 94105,
                "customer_latitude": 37.7749,
                "customer_longitude": -122.4194,
                "customer_city_population": 870000,
            }
        }
    )
class TransactionProcessingOutput(BaseModel):
    """
    Response model of the ``/transaction/process`` endpoint.

    Holds the binary fraud prediction (``is_fraud``) together with the fraud
    probability score (``fraud_score``) produced for the submitted transaction.
    """

    # Schema example shown in the generated API documentation.
    model_config: ClassVar[ConfigDict] = ConfigDict(
        json_schema_extra={"example": {"is_fraud": 1, "fraud_score": 0.85}},
    )

    # Both fields are required: Field() is given no default.
    is_fraud: int = Field(
        description="The prediction result. 1 if the transaction is detected as fraudulent, 0 otherwise.",
    )
    fraud_score: float = Field(
        description="Probability of transaction being fraudulent.",
    )