Spaces:
Running
Running
| import os | |
| import logging | |
| import secrets | |
| import requests | |
| from typing import Annotated, Generator | |
| from fastapi import ( | |
| FastAPI, | |
| Request, | |
| HTTPException, | |
| Security, | |
| Depends | |
| ) | |
| from fastapi.background import BackgroundTasks | |
| from fastapi.responses import RedirectResponse, JSONResponse | |
| from fastapi.security.api_key import APIKeyHeader | |
| from psycopg.errors import UniqueViolation | |
| from sqlalchemy.exc import IntegrityError | |
| from starlette.status import ( | |
| HTTP_200_OK, | |
| HTTP_403_FORBIDDEN, | |
| HTTP_422_UNPROCESSABLE_ENTITY, | |
| HTTP_500_INTERNAL_SERVER_ERROR, | |
| HTTP_503_SERVICE_UNAVAILABLE) | |
| from dotenv import load_dotenv | |
| from sqlalchemy import text | |
| from sqlalchemy.orm import Session | |
| from src.entity.api.transaction_api import TransactionApi, TransactionProcessingOutput | |
| from src.service.fraud_service import check_for_fraud_api | |
| from src.service.logging_service import setup_logging | |
| from src.service.notification_service import send_notification | |
| from src.repository.common import get_session | |
| from src.repository.fraud_details_repo import insert_fraud | |
| from src.repository.transaction_repo import insert_transaction, fetch_transaction_by_number | |
# ------------------------------------------------------------------------------
# Logging: apply the project-wide logging configuration from
# src.service.logging_service, then create this module's logger.
setup_logging()
logger = logging.getLogger(name=__name__)
# ------------------------------------------------------------------------------
def provide_connection() -> Generator[Session, None, None]:
    """
    FastAPI dependency that yields a database session.

    The session comes from ``get_session()`` used as a context manager; the
    context is exited when FastAPI finalizes this dependency, so whatever
    cleanup ``get_session`` performs runs then.
    """
    session_context = get_session()
    with session_context as session:
        yield session
# ------------------------------------------------------------------------------
# Environment configuration: load variables from a .env file (if present) into
# the process environment, then read the settings this module needs.
# NOTE(review): this runs after setup_logging() and after the src.* imports at
# the top of the file; if any of those read .env-provided values at import or
# setup time they will not see them -- confirm.
load_dotenv()
FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")
FRAUD_ML_API_KEY = os.environ.get("FRAUD_ML_API_KEY")
FRAUD_ML_HEALTHCHECK_ENDPOINT = os.environ.get("FRAUD_ML_HEALTHCHECK_ENDPOINT")
# Client hosts that validate_api_key lets through without checking the API key.
safe_clients = ['127.0.0.1']
# Reads the API key from the "Authorization" request header. With
# auto_error=False a missing header is passed on as None instead of being
# rejected by FastAPI itself.
api_key_header = APIKeyHeader(auto_error=False, name='Authorization')
async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    '''
    Check if the API key is valid.

    Requests from a host listed in ``safe_clients`` are accepted without a key;
    every other request must carry a key equal to ``FASTAPI_API_KEY``.

    Args:
        request (Request): The incoming request, used to read the client host.
        key (str): The API key read from the ``Authorization`` header; ``None``
            when the header is absent (``api_key_header`` uses auto_error=False).
    Raises:
        HTTPException: 403 if the API key is missing or does not match.
    '''
    # request.client may be None (no peer address reported by the server);
    # treat that as an untrusted client instead of raising AttributeError.
    client_host = request.client.host if request.client is not None else None
    # NOTE(review): if the app runs behind a local reverse proxy, every request
    # may appear to come from 127.0.0.1 and bypass the key check -- confirm.
    if client_host in safe_clients:
        return None
    # compare_digest only accepts ASCII-only str; comparing UTF-8 bytes makes a
    # non-ASCII header value a plain 403 instead of a TypeError (HTTP 500).
    if key is None or not secrets.compare_digest(
        key.encode("utf-8"), str(FASTAPI_API_KEY).encode("utf-8")
    ):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
# The API-key check is installed as a global dependency only when
# FASTAPI_API_KEY is configured; when it is unset, no endpoint requires a key.
app = FastAPI(
    title="Fraud Detection Service API",
    dependencies=None if not FASTAPI_API_KEY else [Depends(validate_api_key)],
)
# ------------------------------------------------------------------------------
def redirect_to_docs():
    '''
    Send the caller to the interactive API documentation page.

    Returns:
        RedirectResponse: a redirect whose target URL is ``/docs``.
    '''
    # NOTE(review): no route decorator for this function appears in this copy
    # of the file -- confirm it is registered on ``app``.
    target = '/docs'
    return RedirectResponse(url=target)
async def get_fraud_status(
    transaction_number: str,
    db: Annotated[Session, Depends(provide_connection)]
):
    """
    Get the fraud status of a transaction.

    Args:
        transaction_number: Identifier of the transaction to look up.
        db: Database session injected by ``provide_connection``.

    Returns:
        dict: ``{'is_fraud': bool, 'fraud_score': ...}``. ``is_fraud`` is True
        only when the stored ``is_fraud`` flag is exactly ``True``;
        ``fraud_score`` comes from the transaction's fraud details, or is -1
        when it has none.

    Raises:
        HTTPException: 422 if the transaction does not exist, 500 if the lookup
        fails for any other reason.
    """
    # NOTE(review): no route decorator appears in this copy of the file, and the
    # repository call below is synchronous inside an async endpoint -- confirm.
    try:
        transaction = fetch_transaction_by_number(
            db=db,
            transaction_number=transaction_number
        )
    except ValueError as e:
        # ValueError from the repository is treated as "transaction not found".
        logger.error(e)
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Transaction {transaction_number} does not exist"
        ) from e
    except Exception as e:
        # Keep the traceback in the logs; the client only gets a generic message.
        logger.exception(e)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An error occurred while fetching the transaction. See logs for details."
        ) from e
    # Defensive: a None result would otherwise crash on attribute access below.
    if transaction is None:
        logger.error(f"Transaction {transaction_number} not found")
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Transaction {transaction_number} does not exist"
        )
    # Check if the transaction is fraudulent
    is_fraud = transaction.is_fraud is True
    fraud_score = transaction.fraud_details.fraud_score if transaction.fraud_details else -1
    return {
        'is_fraud': is_fraud,
        'fraud_score': fraud_score
    }
# ------------------------------------------------------------------------------
async def process_transaction(
    background_tasks: BackgroundTasks,
    transactionApi: TransactionApi,
    db: Annotated[Session, Depends(provide_connection)]
):
    """
    Process a transaction.

    The transaction is validated, inserted into the database and scored by the
    fraud detection service. When it is flagged as fraud (``result == 1``), the
    fraud details are stored and ``send_notification`` is scheduled as a
    background task.

    If the insert raises an integrity error (presumably a transaction that was
    already submitted), the session is rolled back and the stored status from
    ``get_fraud_status`` is returned instead.

    Args:
        background_tasks: FastAPI background task queue used for notifications.
        transactionApi: The incoming transaction payload.
        db: Database session injected by ``provide_connection``.

    Returns:
        dict: ``{'is_fraud': fraud_output.result,
        'fraud_score': fraud_output.fraud_probability}``.
        NOTE(review): the duplicate path returns ``get_fraud_status``'s shape
        (``is_fraud`` is a bool, ``fraud_score`` is -1 without fraud details),
        while this path returns the raw ``result`` value -- confirm clients
        accept both.

    Raises:
        HTTPException: 422 if the payload is invalid; 500 if the transaction
        cannot be stored or the fraud check fails.
    """
    # Check the transaction
    if not transactionApi.is_valid():
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Transaction is not valid. Check input values."
        )
    # Convert the API object to a Transaction object
    transaction = transactionApi.to_transaction()
    try:
        # Insert every single transaction into the database
        transaction = insert_transaction(db, transaction)
    except (UniqueViolation, IntegrityError):
        logger.warning("Transaction cannot be inserted in the database - checking existence...")
        db.rollback()
        return await get_fraud_status(
            transaction_number=transaction.transaction_number,
            db=db
        )
    except Exception as e:
        # Keep the traceback in the logs; the client only gets a generic message.
        logger.exception(e)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An error occurred while processing the transaction. See logs for details."
        ) from e
    # Call the fraud detection API.
    # NOTE(review): if scoring fails, the transaction has already gone through
    # insert_transaction; a retry would then take the duplicate path above and
    # may report it as not fraudulent. Confirm whether that is acceptable.
    try:
        fraud_output = check_for_fraud_api(transaction)
    except HTTPException:
        # Already an HTTP error meant for the client; pass it through unchanged.
        raise
    except Exception as e:
        logger.exception(e)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An error occurred while checking the transaction for fraud. See logs for details."
        ) from e
    is_fraud = fraud_output.result == 1
    if is_fraud:
        # The membership test (rather than .get) works for any container that
        # supports `in`, which matches the original behaviour.
        insert_fraud(
            db=db,
            transaction=transaction,
            fraud_score=fraud_output.fraud_probability,
            model_version=fraud_output.model_metadata['version'] if 'version' in fraud_output.model_metadata else 'unknown'
        )
        # Send notification to the user once the response has been produced
        background_tasks.add_task(
            func=send_notification,
            transaction_id=transaction.id)
    # Return the result
    output = {
        'is_fraud': fraud_output.result,
        'fraud_score': fraud_output.fraud_probability
    }
    return output
# ------------------------------------------------------------------------------
# Maximum number of seconds the health check waits for the fraud detection API.
# requests applies no timeout by default, so without this the check could hang.
FRAUD_ML_HEALTHCHECK_TIMEOUT_SECONDS = 5


async def check_health(session: Annotated[Session, Depends(provide_connection)]):
    """
    Check all the services in the infrastructure are working.

    Runs ``SELECT 1`` against the database, then calls the fraud detection
    API's health-check endpoint.

    Args:
        session: Database session injected by ``provide_connection``.

    Returns:
        JSONResponse: ``{"status": "healthy"}`` with 200 when both checks pass,
        otherwise ``{"status": "unhealthy"}`` with 503.
    """
    # NOTE(review): session.execute and requests.get are blocking calls running
    # inside an async endpoint, so they hold the event loop while they wait.
    # Check if the database is alive
    try:
        session.execute(text("SELECT 1"))
    except Exception as e:
        logger.error(f"DB check failed: {e}")
        return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
    # Check if the fraud detection API is alive. Connection errors, timeouts
    # and a missing/invalid endpoint URL are reported as "unhealthy" instead of
    # escaping as an unhandled 500.
    try:
        response = requests.get(
            FRAUD_ML_HEALTHCHECK_ENDPOINT,
            headers={
                'Content-Type': 'application/json',
                'Authorization': FRAUD_ML_API_KEY,
            },
            timeout=FRAUD_ML_HEALTHCHECK_TIMEOUT_SECONDS,
        )
    except requests.RequestException as e:
        logger.error(f"Fraud detection API check failed: {e}")
        return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
    if response.status_code != HTTP_200_OK:
        logger.error(f"Fraud detection API check failed: {response.status_code} - {response.text}")
        return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)