# Hosting-page residue, kept as comments so the file parses:
# author: raharjo · commit message: "Update main.py" · commit 154b9ea (verified)
#!/usr/bin/env python
# encoding: utf-8
import logging
import pickle

import joblib
import pandas as pd
from fastapi import FastAPI, Form, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sklearn.preprocessing import LabelEncoder

from extraction_features import extract_features
app = FastAPI()

# CORS: accept requests from any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site send credentialed requests; narrow allow_origins if that matters here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Trained artifacts, loaded once at import time (paths are relative to the
# working directory). joblib.load unpickles its input, so these must only ever
# point at trusted files.
model = joblib.load('logistic_regression_model.pkl')
label_encoders = joblib.load('label_encoders.pkl')

# Feature columns fed to the model. Order matters: the predict handler passes
# these values positionally (presumably matching the training column order).
selected_columns = [
    'URL',
    'URLLength',
    'Domain',
    'DomainLength',
    'TLD',
    'CharContinuationRate',
    'TLDLength',
    'NoOfSubDomain',
    'DegitRatioInURL',
    'SpacialCharRatioInURL',
    'IsHTTPS',
]
def safe_transform(encoder, value):
    """Encode ``value`` with a fitted label encoder, tolerating unseen values.

    Args:
        encoder: A fitted encoder exposing ``classes_`` and ``transform``
            (e.g. a scikit-learn ``LabelEncoder``).
        value: The raw value to encode.

    Returns:
        The encoder's code for ``value`` when it is one of ``encoder.classes_``;
        otherwise ``-1``. All unseen values share that same sentinel.
    """
    known_classes = encoder.classes_
    if value not in known_classes:
        return -1  # sentinel shared by every value the encoder never saw
    encoded = encoder.transform([value])
    return encoded[0]
# JSON body for POST /path: the text that demo_post upper-cases.
class Msg(BaseModel):
    msg: str
# Input for POST /predict, built by form_req from the `url` form field.
class Req(BaseModel):
    url: str  # the URL to classify
# Response payload of POST /predict.
class Resp(BaseModel):
    url: str  # the URL exactly as submitted
    label: str  # the model's predicted class value, converted with str()
# GET /: returns a fixed greeting message.
@app.get("/")
async def root():
    greeting = "Hello, Welcome to the final project from Albin Tardivel"
    return {"message": greeting}
def form_req(url: str = Form(...)):
    """FastAPI dependency that wraps the ``url`` form field in a ``Req``.

    ``Form(...)`` declares the field as a required form-encoded value.
    """
    payload = {"url": str(url)}
    return Req(**payload)
# GET /path: static hint pointing callers at the POST variant.
@app.get("/path")
async def demo_get():
    hint = "This is /path endpoint, use a post request to transform the text to uppercase"
    return {"message": hint}
# POST /path: echo the submitted message upper-cased.
@app.post("/path")
async def demo_post(inp: Msg):
    shouted = inp.msg.upper()
    return dict(message=shouted)
# GET /path/{path_id}: informational only; echoes the id back in the text.
@app.get("/path/{path_id}")
async def demo_get_path_id(path_id: int):
    text = f"This is /path/{path_id} endpoint, use post request to retrieve result"
    return {"message": text}
# GET /predict/{path_id}: informational only; predictions come from POST /predict.
# NOTE(review): this function shares the name `predict` with the POST handler
# defined after it, so the module-level name ends up bound to that handler.
# Both routes are still registered (the decorator runs at definition time),
# but linters flag the redefinition (F811); consider a distinct name.
@app.get("/predict/{path_id}")
async def predict(path_id: int):
    text = f"This is /predict/{path_id} endpoint, use post request to retrieve result"
    return {"message": text}
_logger = logging.getLogger(__name__)


@app.post("/predict")
async def predict(request: Request, requess: Req = Depends(form_req)):
    """Classify the submitted URL and return the model's label as JSON.

    The URL arrives as the ``url`` form field. Its features are computed with
    ``extract_features``. The ``URL``, ``Domain`` and ``TLD`` columns are
    label-encoded with the fitted encoders (unseen values become ``-1``). The
    ``selected_columns`` are then passed to the model.

    Returns a JSON object ``{"url": <submitted url>, "label": <predicted class as str>}``.
    """
    # NOTE(review): `request` is unused; it is kept so the signature is unchanged.
    # NOTE(review): this `async def` performs synchronous CPU work (feature
    # extraction, pandas, model.predict), which blocks the event loop while it
    # runs; a plain `def` handler would be run in FastAPI's threadpool instead.
    url = requess.url
    features = extract_features(str(url))
    features_df = pd.DataFrame([features])

    # Replace the string-valued columns with the integer codes the model
    # consumes. The encoder is bound as a default argument so each lambda uses
    # its own column's encoder.
    for column in ('URL', 'Domain', 'TLD'):
        encoder = label_encoders[column]
        features_df[column] = features_df[column].apply(
            lambda value, enc=encoder: safe_transform(enc, value)
        )

    # Single row, columns in `selected_columns` order (fed positionally).
    selected = features_df[selected_columns]
    data = selected.to_numpy().reshape(1, -1)

    # Diagnostics go through logging at DEBUG level instead of print(), so the
    # full feature vector (including the URL) is not written to stdout on
    # every request.
    _logger.debug("Nb dimensions before prediction: %s", data.shape)
    _logger.debug("Data sent to predict: %s", data)
    _logger.debug("Data types: %s", selected.dtypes)

    prediction = model.predict(data)
    # NOTE(review): the raw class value is returned as text. A mapping such as
    # 1 -> "Legitimate", 0 -> "Phishing" was sketched before but is not applied.
    output_text = str(prediction[0])

    response = Resp(url=requess.url, label=output_text)
    return JSONResponse(content=jsonable_encoder(response))