SlimG's picture
move TransactionProcessingOutput to entity file
4e7575f
raw
history blame
7.8 kB
from typing import Optional, Literal, ClassVar
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime
from src.entity.transaction import Transaction
import logging
# Module-level logger, named after this module via __name__; used by the models below.
logger = logging.getLogger(__name__)
class TransactionApi(BaseModel):
    """
    API input model describing one transaction with its merchant and customer.

    ``transaction_timestamp`` is handled as a Unix timestamp expressed in
    milliseconds: both :meth:`is_valid` and :meth:`to_transaction` divide it
    by 1000 before building a ``datetime``.
    """

    transaction_number: str = Field(..., title="Transaction number", description="The ID of the transaction.")
    # Unix epoch in milliseconds (see is_valid / to_transaction).
    transaction_timestamp: int = Field(..., title="Timestamp", description="The timestamp of the transaction.")
    transaction_amount: float = Field(..., ge=0, title="Amount", description="The transaction amount.")
    transaction_category: Optional[str] = Field(None, title="Product category", description="The category of product of the transaction.")
    transaction_is_real_fraud: Optional[bool] = Field(None, title="Ground truth", description="True if the transaction really is a fraud (ground truth).")
    merchant_name: str = Field(..., title="Name", description="The name of the merchant.")
    merchant_latitude: float = Field(..., title="Latitude", description="The latitude of the merchant.")
    merchant_longitude: float = Field(..., title="Longitude", description="The longitude of the merchant.")
    customer_credit_card_number: int|str = Field(..., title="Customer's credit card", description="The credit card number of the customer.")
    customer_gender: Optional[Literal['M', 'F']] = Field(None, title="Customer's gender", description="The gender of the customer.")
    customer_first_name: Optional[str] = Field(None, title="Customer's first name", description="The first name of the customer.")
    customer_last_name: Optional[str] = Field(None, title="Customer's last name", description="The last name of the customer.")
    # Expected as a "YYYY-MM-DD" string; the format is enforced by is_valid().
    customer_date_of_birth: Optional[str] = Field(None, title="Customer's date of birth", description="The date of birth of the customer.")
    customer_job: Optional[str] = Field(None, title="Customer's job", description="The job of the customer.")
    customer_street: Optional[str] = Field(None, title="Customer's street", description="The street of the customer.")
    customer_city: Optional[str] = Field(None, title="Customer's city", description="The city of the customer.")
    customer_state: Optional[str] = Field(None, title="Customer's state", description="The state of the customer.")
    customer_postal_code: Optional[int] = Field(None, title="Customer's postal code", description="The postal code of the customer.")
    customer_latitude: float = Field(..., title="Customer's latitude", description="The latitude of the customer.")
    customer_longitude: float = Field(..., title="Customer's longitude", description="The longitude of the customer.")
    customer_city_population: Optional[int] = Field(None, ge=0, title="Customer's city population", description="The population of the city.")

    # Pydantic v2 configuration (replaces the deprecated nested ``class Config``).
    # The example timestamp is in milliseconds, matching how the model reads it.
    model_config: ClassVar[ConfigDict] = ConfigDict(
        json_schema_extra={
            "example": {
                "transaction_number": "123456789",
                "transaction_timestamp": 1633036800000,
                "transaction_amount": 100.0,
                "transaction_category": "Electronics",
                "merchant_name": "Best Buy",
                "merchant_latitude": 37.7749,
                "merchant_longitude": -122.4194,
                "customer_credit_card_number": "4111111111111111",
                "customer_gender": "M",
                "customer_first_name": "John",
                "customer_last_name": "Doe",
                "customer_date_of_birth": "1980-01-01",
                "customer_job": "Engineer",
                "customer_street": "123 Main St",
                "customer_city": "San Francisco",
                "customer_state": "CA",
                "customer_postal_code": 94105,
                "customer_latitude": 37.7749,
                "customer_longitude": -122.4194,
                "customer_city_population": 870000,
            }
        }
    )

    def is_valid(self) -> bool:
        """
        Check if the transaction is valid.

        The checks are, in order:

        1. the amount is not negative;
        2. the timestamp (milliseconds) converts to a datetime that is not
           in the future;
        3. if a date of birth is given, it parses as ``YYYY-MM-DD`` and the
           customer is at least 10 years old.

        Returns:
            True when every check passes, False otherwise. Each failure is
            logged at error level; no exception is raised for bad input.
        """
        logger.debug("Validating transaction...")
        # Naive local time, truncated to whole seconds; compared against the
        # naive local datetimes built below.
        now = datetime.now().replace(microsecond=0)

        # The transaction amount must not be negative (zero is accepted,
        # consistent with the ``ge=0`` constraint on the field).
        if self.transaction_amount < 0:
            logger.error("Transaction amount is negative.")
            return False

        # Validate the transaction timestamp
        try:
            # The timestamp is in milliseconds; integer-divide to get seconds.
            dt_object = datetime.fromtimestamp(self.transaction_timestamp // 1000)
        except (ValueError, OverflowError, OSError):
            # fromtimestamp raises OverflowError/OSError (not only ValueError)
            # when the value is outside what the platform can represent.
            logger.error("Invalid transaction timestamp.")
            return False
        if dt_object > now:
            logger.error("Transaction timestamp is in the future (%s > %s).", dt_object, now)
            return False

        # Validate the customer date of birth.
        # The customer should be 10 years old at least to possess a credit card.
        # When no date of birth is provided, the check is skipped.
        if self.customer_date_of_birth:
            try:
                dob = datetime.strptime(self.customer_date_of_birth, "%Y-%m-%d")
            except ValueError:
                logger.error("Invalid customer date of birth format. Expected YYYY-MM-DD.")
                return False
            # Calendar-based age: subtract one year if this year's birthday
            # has not happened yet. Dividing a day count by 365 ignores leap
            # days and would accept customers slightly younger than 10.
            birthday_pending = (now.month, now.day) < (dob.month, dob.day)
            age = now.year - dob.year - int(birthday_pending)
            if age < 10:
                logger.error("Customer is younger than 10 years old.")
                return False

        return True

    def to_transaction(self) -> Transaction:
        """
        Build a ``Transaction`` entity from this API model.

        Fields are passed through unchanged under the entity's keyword names,
        except the timestamp, which is converted from milliseconds to a naive
        local ``datetime`` (fractional seconds are kept).

        NOTE(review): the timestamp conversion is not guarded here, so an
        out-of-range value raises; presumably callers run ``is_valid()``
        first -- confirm against the call sites.
        """
        return Transaction(
            transaction_number=self.transaction_number,
            transaction_amount=self.transaction_amount,
            transaction_datetime=datetime.fromtimestamp(self.transaction_timestamp / 1000),
            transaction_category=self.transaction_category,
            merchant_name=self.merchant_name,
            merchant_address_latitude=self.merchant_latitude,
            merchant_address_longitude=self.merchant_longitude,
            customer_credit_card_number=self.customer_credit_card_number,
            customer_gender=self.customer_gender,
            customer_firstname=self.customer_first_name,
            customer_lastname=self.customer_last_name,
            customer_dob=self.customer_date_of_birth,
            customer_job=self.customer_job,
            customer_address_street=self.customer_street,
            customer_address_city=self.customer_city,
            customer_address_state=self.customer_state,
            customer_address_zip=self.customer_postal_code,
            customer_address_latitude=self.customer_latitude,
            customer_address_longitude=self.customer_longitude,
            customer_address_city_population=self.customer_city_population,
            is_real_fraud=self.transaction_is_real_fraud,
        )
class TransactionProcessingOutput(BaseModel):
    """
    Response model for the /transaction/process endpoint.

    Carries the outcome of fraud detection for one transaction: the
    predicted label (``is_fraud``) and the associated score (``fraud_score``).
    """

    # Predicted label.
    is_fraud: int = Field(
        description="The prediction result. 1 if the transaction is detected as fraudulent, 0 otherwise.",
    )
    # Score attached to the prediction.
    fraud_score: float = Field(
        description="Probability of transaction being fraudulent.",
    )

    # Example payload shown in the generated JSON schema.
    model_config: ClassVar[ConfigDict] = ConfigDict(
        json_schema_extra={"example": {"is_fraud": 1, "fraud_score": 0.85}},
    )