sghorbal commited on
Commit
434466a
·
1 Parent(s): cca1a3f

connect to the ML prediction API

Browse files
.env.example CHANGED
@@ -4,6 +4,11 @@ DATABASE_URL=
4
  # If set, protects the API from unauthorized called
5
  FASTAPI_API_KEY=
6
 
 
 
 
 
 
7
  # Mail notification configurations
8
  RECEIVER_EMAIL="jedha.fraud@yopmail.com"
9
 
 
4
  # If set, protects the API from unauthorized called
5
  FASTAPI_API_KEY=
6
 
7
+ # API of the ML model
8
+ FRAUD_ML_API_KEY=
9
+ FRAUD_ML_HEALTHCHECK_ENDPOINT=
10
+ FRAUD_ML_PREDICTION_ENDPOINT=
11
+
12
  # Mail notification configurations
13
  RECEIVER_EMAIL="jedha.fraud@yopmail.com"
14
 
src/entity/api/fraud_prediction_api.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ import logging
3
+
4
+ # Set up logging
5
+ logging.basicConfig(level=logging.INFO)
6
+ logger = logging.getLogger(__name__)
7
+
8
+ class FraudPredictionInput(BaseModel):
9
+ """
10
+ FraudPredictionInput is a class that represents the input data for fraud prediction.
11
+ """
12
+ transaction_category: str
13
+ transaction_amount: float
14
+ customer_job: str
15
+ customer_address_state: str
16
+ customer_address_city: str
17
+ customer_address_city_population: int
18
+
19
+ class FraudPredictionOutput(BaseModel):
20
+ """
21
+ FraudPredictionOutput is a class that represents the output data for fraud prediction.
22
+ """
23
+ result: int
24
+ fraud_probability: float
25
+ model_metadata: dict
src/main.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import logging
3
  import secrets
 
4
  from typing import Annotated
5
  from fastapi import (
6
  FastAPI,
@@ -10,19 +11,22 @@ from fastapi import (
10
  Depends
11
  )
12
  from fastapi.background import BackgroundTasks
13
- from fastapi.responses import RedirectResponse
14
  from fastapi.security.api_key import APIKeyHeader
15
  from pydantic import BaseModel, Field
16
  from psycopg.errors import UniqueViolation, IntegrityError
17
  from starlette.status import (
 
18
  HTTP_403_FORBIDDEN,
19
  HTTP_422_UNPROCESSABLE_ENTITY,
20
- HTTP_500_INTERNAL_SERVER_ERROR)
 
21
  from dotenv import load_dotenv
22
  from src.entity.api.transaction_api import TransactionApi
 
23
  from sqlalchemy.orm import Session
24
 
25
- from src.service.fraud_service import check_for_fraud
26
  from src.service.notification_service import send_notification
27
  from src.repository.common import get_session
28
  from src.repository.fraud_details_repo import insert_fraud
@@ -32,6 +36,8 @@ from src.repository.transaction_repo import insert_transaction
32
 
33
  load_dotenv()
34
  FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
 
 
35
  safe_clients = ['127.0.0.1']
36
 
37
  api_key_header = APIKeyHeader(name='Authorization', auto_error=False)
@@ -120,15 +126,16 @@ async def process_transaction(
120
  detail="An error occurred while processing the transaction. See logs for details."
121
  )
122
 
123
- # Check for fraud
124
- is_fraud = check_for_fraud(transaction)
 
125
 
126
  if is_fraud:
127
  insert_fraud(
128
  db=db,
129
  transaction=transaction,
130
- fraud_score=0.5,
131
- model_version='latest'
132
  )
133
 
134
  # Send notification to the user
@@ -138,8 +145,8 @@ async def process_transaction(
138
 
139
  # Return the result
140
  output = {
141
- 'is_fraud': 1 if is_fraud else 0,
142
- 'fraud_score': 0.5 if is_fraud else 0.0
143
  }
144
 
145
  return output
@@ -150,18 +157,22 @@ async def check_health(session: Annotated[Session, Depends(get_session)]):
150
  """
151
  Check all the services in the infrastructure are working
152
  """
153
- healthy = 0
154
- unhealthy = 1
155
-
156
- # DB check
157
- db_status = False
158
  try:
159
- session.execute("SELECT 1")
160
- db_status = True
161
- except Exception:
162
- pass
163
-
164
- if db_status:
165
- return healthy
166
- else:
167
- return unhealthy
 
 
 
 
 
 
 
 
 
1
  import os
2
  import logging
3
  import secrets
4
+ import requests
5
  from typing import Annotated
6
  from fastapi import (
7
  FastAPI,
 
11
  Depends
12
  )
13
  from fastapi.background import BackgroundTasks
14
+ from fastapi.responses import RedirectResponse, JSONResponse
15
  from fastapi.security.api_key import APIKeyHeader
16
  from pydantic import BaseModel, Field
17
  from psycopg.errors import UniqueViolation, IntegrityError
18
  from starlette.status import (
19
+ HTTP_200_OK,
20
  HTTP_403_FORBIDDEN,
21
  HTTP_422_UNPROCESSABLE_ENTITY,
22
+ HTTP_500_INTERNAL_SERVER_ERROR,
23
+ HTTP_503_SERVICE_UNAVAILABLE)
24
  from dotenv import load_dotenv
25
  from src.entity.api.transaction_api import TransactionApi
26
+ from sqlalchemy import text
27
  from sqlalchemy.orm import Session
28
 
29
+ from src.service.fraud_service import check_for_fraud_api
30
  from src.service.notification_service import send_notification
31
  from src.repository.common import get_session
32
  from src.repository.fraud_details_repo import insert_fraud
 
36
 
37
  load_dotenv()
38
  FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
39
+ FRAUD_ML_API_KEY = os.getenv("FRAUD_ML_API_KEY")
40
+ FRAUD_ML_HEALTHCHECK_ENDPOINT = os.getenv("FRAUD_ML_HEALTHCHECK_ENDPOINT")
41
  safe_clients = ['127.0.0.1']
42
 
43
  api_key_header = APIKeyHeader(name='Authorization', auto_error=False)
 
126
  detail="An error occurred while processing the transaction. See logs for details."
127
  )
128
 
129
+ # Call the fraud detection API
130
+ fraud_output = check_for_fraud_api(transaction)
131
+ is_fraud = fraud_output.result == 1
132
 
133
  if is_fraud:
134
  insert_fraud(
135
  db=db,
136
  transaction=transaction,
137
+ fraud_score=fraud_output.fraud_probability,
138
+ model_version=fraud_output.model_metadata['model_version'] if 'model_version' in fraud_output.model_metadata else 'unknown'
139
  )
140
 
141
  # Send notification to the user
 
145
 
146
  # Return the result
147
  output = {
148
+ 'is_fraud': fraud_output.result,
149
+ 'fraud_score': fraud_output.fraud_probability
150
  }
151
 
152
  return output
 
157
  """
158
  Check all the services in the infrastructure are working
159
  """
160
+ # Check if the database is alive
 
 
 
 
161
  try:
162
+ session.execute(text("SELECT 1"))
163
+ except Exception as e:
164
+ logging.error(f"DB check failed: {e}")
165
+ return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
166
+
167
+ # Check if the fraud detection API is alive
168
+ response = requests.get(FRAUD_ML_HEALTHCHECK_ENDPOINT,
169
+ headers={
170
+ 'Content-Type': 'application/json',
171
+ 'Authorization': FRAUD_ML_API_KEY,
172
+ })
173
+
174
+ if response.status_code != HTTP_200_OK:
175
+ logging.error(f"Fraud detection API check failed: {response.status_code} - {response.text}")
176
+ return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
177
+
178
+ return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)
src/service/fraud_service.py CHANGED
@@ -1,10 +1,56 @@
 
1
  import logging
 
 
2
  from src.entity.transaction import Transaction
 
 
 
 
 
 
 
 
 
3
 
4
  # Configure logging
5
  logging.basicConfig(level=logging.DEBUG)
6
  logger = logging.getLogger(__name__)
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def check_for_fraud(transaction: Transaction) -> bool:
10
  """
 
1
+ import os
2
  import logging
3
+ from typing import Optional
4
+ import requests
5
  from src.entity.transaction import Transaction
6
+ from src.entity.api.fraud_prediction_api import FraudPredictionInput, FraudPredictionOutput
7
+ from dotenv import load_dotenv
8
+
9
+ # Load environment variables
10
+ load_dotenv()
11
+
12
+ # Set up the API key and endpoint
13
+ FRAUD_ML_API_KEY = os.getenv("FRAUD_ML_API_KEY")
14
+ FRAUD_ML_PREDICTION_ENDPOINT = os.getenv("FRAUD_ML_PREDICTION_ENDPOINT")
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.DEBUG)
18
  logger = logging.getLogger(__name__)
19
 
20
+ def check_for_fraud_api(transaction: Transaction) -> Optional[FraudPredictionOutput]:
21
+ """
22
+ Check for fraud in the transaction API.
23
+ """
24
+ logger.debug("Checking for fraud in the API...")
25
+
26
+ # Create an instance of the FraudPredictionInput model
27
+ fraud_input = FraudPredictionInput(
28
+ transaction_category=transaction.transaction_category,
29
+ transaction_amount=transaction.transaction_amount,
30
+ customer_job=transaction.customer_job,
31
+ customer_address_state=transaction.customer_address_state,
32
+ customer_address_city=transaction.customer_address_city,
33
+ customer_address_city_population=transaction.customer_address_city_population
34
+ )
35
+
36
+ # Send a POST request to the fraud detection API
37
+ response = requests.post(FRAUD_ML_PREDICTION_ENDPOINT,
38
+ json=fraud_input.model_dump(mode='json'),
39
+ headers={
40
+ 'Content-Type': 'application/json',
41
+ 'Authorization': FRAUD_ML_API_KEY,
42
+ })
43
+
44
+ if response.status_code == 200:
45
+ fraud_response = response.json()
46
+ logger.info(f"Fraud detection API response: {fraud_response}")
47
+
48
+ fraud_output = FraudPredictionOutput(**fraud_response)
49
+
50
+ return fraud_output
51
+ else:
52
+ logger.error(f"Failed to call fraud detection API: {response.status_code} - {response.text}")
53
+ return None
54
 
55
  def check_for_fraud(transaction: Transaction) -> bool:
56
  """