LazyBoss commited on
Commit
669c133
·
verified ·
1 Parent(s): 4e97de9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +47 -32
main.py CHANGED
@@ -1,56 +1,71 @@
1
  import joblib
2
  from pydantic import BaseModel
3
- from fastapi import FastAPI
4
  import uvicorn
5
  import logging
 
 
 
6
 
7
- logging.basicConfig(level = logging.INFO)
8
  # 1. Load the trained model
9
- model = joblib.load('frauddetection.pkl')
 
 
 
 
 
 
 
10
 
11
  # 2. Define the input data schema using Pydantic BaseModel
12
  class InputData(BaseModel):
13
- Year:int
14
- Month:int
15
- UseChip:int
16
- Amount:int
17
- MerchantName:int
18
- MerchantCity:int
19
- MerchantState:int
20
- mcc:int
21
- # Add the rest of the input features (feature4, feature5, ..., feature12)
22
 
23
  # 3. Create a FastAPI app
24
  app = FastAPI()
 
25
  @app.get('/')
26
  def welcome():
27
  return {"Welcome": "This is the home page of the API"}
28
 
29
-
30
  # 4. Define the prediction route
31
  @app.post('/predict/')
32
  async def predict(data: InputData):
 
 
 
33
  # Convert the input data to a dictionary
34
  input_data = data.dict()
35
 
36
  # Extract the input features from the dictionary
37
- feature1 = input_data['Year']
38
- feature2=input_data['Month']
39
- feature3=input_data['UseChip']
40
- feature4=input_data['Amount']
41
- feature5=input_data['MerchantName']
42
- feature6=input_data['MerchantCity']
43
- feature7=input_data['MerchantState']
44
- feature8=input_data['mcc']
45
- # Extract the rest of the input features (feature4, feature5, ..., feature12)
46
-
47
- # Perform the prediction using the loaded model
48
- prediction = model.predict([[feature1, feature2, feature3,feature4,feature5,feature6,feature7,feature8]]) # Replace ... with the rest of the features
49
- # Convert the prediction to a string (or any other format you prefer)
50
- result = "Fraud" if prediction[0] == 1 else "Not a Fraud"
51
- # result = prediction
52
- return {"prediction": result}
53
- # 4. Run the API with uvicorn
54
- # Will run on http://127.0.0.1:8000
 
 
 
 
55
  if __name__ == '__main__':
56
- uvicorn.run(app, port=8080)
 
1
  import joblib
2
  from pydantic import BaseModel
3
+ from fastapi import FastAPI, HTTPException
4
  import uvicorn
5
  import logging
6
+ import os
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
 
 
10
  # 1. Load the trained model
11
+ try:
12
+ if os.path.exists('frauddetection.pkl'):
13
+ model = joblib.load('frauddetection.pkl')
14
+ else:
15
+ raise FileNotFoundError("Model file 'frauddetection.pkl' not found.")
16
+ except Exception as e:
17
+ logging.error(f"Error loading model: {e}")
18
+ model = None
19
 
20
  # 2. Define the input data schema using Pydantic BaseModel
21
  class InputData(BaseModel):
22
+ Year: int
23
+ Month: int
24
+ UseChip: int
25
+ Amount: int
26
+ MerchantName: int
27
+ MerchantCity: int
28
+ MerchantState: int
29
+ mcc: int
 
30
 
31
  # 3. Create a FastAPI app
32
  app = FastAPI()
33
+
34
  @app.get('/')
35
  def welcome():
36
  return {"Welcome": "This is the home page of the API"}
37
 
 
38
  # 4. Define the prediction route
39
  @app.post('/predict/')
40
  async def predict(data: InputData):
41
+ if model is None:
42
+ raise HTTPException(status_code=500, detail="Model not loaded properly.")
43
+
44
  # Convert the input data to a dictionary
45
  input_data = data.dict()
46
 
47
  # Extract the input features from the dictionary
48
+ feature_list = [
49
+ input_data['Year'],
50
+ input_data['Month'],
51
+ input_data['UseChip'],
52
+ input_data['Amount'],
53
+ input_data['MerchantName'],
54
+ input_data['MerchantCity'],
55
+ input_data['MerchantState'],
56
+ input_data['mcc']
57
+ ]
58
+
59
+ try:
60
+ # Perform the prediction using the loaded model
61
+ prediction = model.predict([feature_list]) # Ensure model expects the correct number of features
62
+ result = "Fraud" if prediction[0] == 1 else "Not a Fraud"
63
+ return {"prediction": result}
64
+ except Exception as e:
65
+ logging.error(f"Prediction error: {e}")
66
+ raise HTTPException(status_code=500, detail="An error occurred during prediction.")
67
+
68
+ # 5. Run the API with uvicorn
69
+ # Will run on http://127.0.0.1:8080
70
  if __name__ == '__main__':
71
+ uvicorn.run(app, host="127.0.0.1", port=8080)