# NOTE: the lines that preceded the shebang were Hugging Face Spaces
# file-viewer chrome (blob hashes and a line-number gutter), not source code.
# They have been commented out so the file is valid Python.
# File size reported by the viewer: 4,416 bytes.
#!/usr/bin/env python
# encoding: utf-8
from fastapi import FastAPI, Form, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import pickle
import joblib
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from extraction_features import extract_features
# FastAPI application instance; all route decorators below attach to it.
app = FastAPI()
# Add CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive — fine for a demo, tighten for production.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# model_file = open('logistic_regression_model.pkl', 'rb')
# model = pickle.load(model_file, encoding='bytes')
# Trained classifier and the per-column LabelEncoders fitted at training
# time, loaded once at import.  Both files must sit next to this script.
model = joblib.load('logistic_regression_model.pkl')
label_encoders = joblib.load('label_encoders.pkl')
# Columns used in the model
# Feature columns, in the exact order the model was trained on.
selected_columns = [
'URL', 'URLLength', 'Domain', 'DomainLength', 'TLD',
'CharContinuationRate', 'TLDLength', 'NoOfSubDomain',
'DegitRatioInURL', 'SpacialCharRatioInURL', 'IsHTTPS'
]
# Encoding helper tolerant to values the encoder has never seen.
def safe_transform(encoder, value):
    """Encode *value* with *encoder*, mapping unseen labels to -1.

    A fitted LabelEncoder raises on labels absent from ``classes_``,
    so unknown values are given the sentinel code -1 instead.
    """
    is_known = value in encoder.classes_
    return encoder.transform([value])[0] if is_known else -1
# Request body for the /path demo endpoint: a single free-text message.
class Msg(BaseModel):
    msg: str
# Prediction request: the URL to classify, taken from an HTML form field.
class Req(BaseModel):
    url: str
# Prediction response: echoes the URL plus the predicted class label.
class Resp(BaseModel):
    url: str
    label: str
@app.get("/")
async def root():
return {"message": "Hello, Welcome to the final project from Albin Tardivel"}
def form_req(url: str = Form(...)):
    """FastAPI dependency: adapt the ``url`` form field into a Req model."""
    raw = str(url)
    return Req(url=raw)
@app.get("/path")
async def demo_get():
return {"message": "This is /path endpoint, use a post request to transform the text to uppercase"}
@app.post("/path")
async def demo_post(inp: Msg):
return {"message": inp.msg.upper()}
@app.get("/path/{path_id}")
async def demo_get_path_id(path_id: int):
return {"message": f"This is /path/{path_id} endpoint, use post request to retrieve result"}
@app.get("/predict/{path_id}")
async def predict(path_id: int):
return {"message": f"This is /predict/{path_id} endpoint, use post request to retrieve result"}
@app.post("/predict")
async def predict(request: Request, requess: Req = Depends(form_req)):
'''
Predict if url is phishing or legitimate
and render the result to the html page
'''
url = requess.url
features = extract_features(str(url))
dataFrame_features = pd.DataFrame([features])
# Apply features encoding (convert everything into int64)
for column in ['URL', 'Domain', 'TLD']:
encoder = label_encoders[column]
dataFrame_features[column] = dataFrame_features[column].apply(lambda x: safe_transform(encoder, x))
data = dataFrame_features[selected_columns].values.reshape(1, -1)
# data = []
# data.append(str(features['URL']))
# data.extend(int(features['URLLength']))
# data.extend(str(features['Domain']))
# data.extend(int(features['DomainLength']))
# data.extend(str(features['TLD']))
# data.extend(float(features['CharContinuationRate']))
# data.extend(int(features['TLDLength']))
# data.extend(int(features['NoOfSubDomain']))
# data.extend(float(features['DegitRatioInURL']))
# data.extend(float(features['SpacialCharRatioInURL']))
# data.extend(int(features['IsHTTPS']))
# data.append(features['URL'])
# data.append(features['URLLength'])
# data.append(features['Domain'])
# data.append(features['DomainLength'])
# data.append(features['TLD'])
# data.append(features['CharContinuationRate'])
# data.append(features['TLDLength'])
# data.append(features['NoOfSubDomain'])
# data.append(features['DegitRatioInURL'])
# data.append(features['SpacialCharRatioInURL'])
# data.append(features['IsHTTPS'])
# Check number of dimensions before prediction
print("Nb dimensions before prediction:", data.shape)
print("Data sent to predict:", data)
print("Data types:", dataFrame_features[selected_columns].dtypes)
prediction = model.predict(data)
output = prediction[0]
output_text = str(output) #"Legitimate" if output == 1 else "Phishing"
# Render index.html with prediction results
json_compatible_resp_data = jsonable_encoder(Resp(url=requess.url, label=output_text))
return JSONResponse(content=json_compatible_resp_data) |