#!/usr/bin/env python
# encoding: utf-8
"""FastAPI service that labels a URL as "Legitimate" or "Phishing".

At import time the module loads a fitted classifier and the label encoders
for the categorical columns from disk, then exposes a few demo endpoints and
a ``POST /predict`` endpoint that runs the classifier on a submitted URL.
"""
import pickle

import joblib
import pandas as pd
from fastapi import FastAPI, Form, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sklearn.preprocessing import LabelEncoder

from extraction_features import extract_features

app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive CORS policy -- confirm it is intended, or replace "*" with
# an explicit list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE: joblib.load is pickle-based and executes code embedded in the file;
# only load artifacts that come from a trusted source.
model = joblib.load('logistic_regression_model.pkl')
label_encoders = joblib.load('label_encoders.pkl')

# Columns used in the model, in the order the feature row is handed to it.
selected_columns = [
    'URLLength',
    'Domain',
    'DomainLength',
    'TLD',
    'CharContinuationRate',
    'TLDLength',
    'NoOfSubDomain',
    'DegitRatioInURL',
    'SpacialCharRatioInURL',
    'IsHTTPS',
]

# Columns that are passed through an entry of ``label_encoders`` before
# prediction.
ENCODED_COLUMNS = ('Domain', 'TLD')

# Minimum probability in column 1 of ``predict_proba`` required to answer
# "Legitimate"; anything below is reported as "Phishing".
LEGITIMATE_THRESHOLD = 0.9


def safe_transform(encoder, value):
    """Encode *value* with a fitted encoder, mapping unseen values to -1.

    Args:
        encoder: Object exposing ``classes_`` and ``transform`` (the entries
            of ``label_encoders`` are used this way).
        value: Raw value to encode.

    Returns:
        The code produced by ``encoder.transform`` when *value* is one of
        ``encoder.classes_``; otherwise ``-1``, so that values never seen
        when the encoder was fitted do not raise.
    """
    if value not in encoder.classes_:
        return -1
    return encoder.transform([value])[0]


class Msg(BaseModel):
    """Request body of ``POST /path``."""

    msg: str


class Req(BaseModel):
    """URL submitted for classification."""

    url: str


class Resp(BaseModel):
    """Classification result returned by ``POST /predict``."""

    url: str
    label: str


@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Hello, Welcome to the final project from Albin Tardivel"}


def form_req(url: str = Form(...)):
    """Dependency that builds a ``Req`` from the required ``url`` form field."""
    return Req(url=str(url))


@app.get("/path")
async def demo_get():
    """Return a hint explaining how to use ``POST /path``."""
    return {"message": "This is /path endpoint, use a post request to transform the text to uppercase"}


@app.post("/path")
async def demo_post(inp: Msg):
    """Return the submitted message converted to upper case."""
    return {"message": inp.msg.upper()}


@app.get("/path/{path_id}")
async def demo_get_path_id(path_id: int):
    """Return a hint message that echoes the integer *path_id*."""
    return {"message": f"This is /path/{path_id} endpoint, use post request to retrieve result"}


# The function has its own identifier so it is no longer shadowed by the POST
# handler below; ``name="predict"`` keeps the route name it had before.
@app.get("/predict/{path_id}", name="predict")
async def predict_get_path_id(path_id: int):
    """Return a hint message that echoes the integer *path_id*."""
    return {"message": f"This is /predict/{path_id} endpoint, use post request to retrieve result"}


@app.post("/predict")
async def predict(request: Request, requess: Req = Depends(form_req)):
    """Classify the submitted URL and return the result as JSON.

    Args:
        request: Incoming request (not used by the handler body).
        requess: ``Req`` built by ``form_req`` from the ``url`` form field.

    Returns:
        A ``JSONResponse`` carrying the serialized ``Resp``: the submitted
        URL and the label ``"Legitimate"`` or ``"Phishing"``.
    """
    url = requess.url
    features = extract_features(str(url))
    dataFrame_features = pd.DataFrame([features])

    # Replace the categorical columns by their encoder codes (-1 if unseen).
    for column in ENCODED_COLUMNS:
        encoder = label_encoders[column]
        # Bind the encoder as a default so the lambda does not depend on the
        # loop variable's later value.
        dataFrame_features[column] = dataFrame_features[column].apply(
            lambda x, enc=encoder: safe_transform(enc, x)
        )

    data = dataFrame_features[selected_columns].values.reshape(1, -1)
    prediction_proba = model.predict_proba(data)[:, 1]
    print("prediction_proba:", prediction_proba)

    # Exactly one row was submitted: compare a plain scalar instead of
    # relying on the truth value of a one-element array.
    legitimate_probability = float(prediction_proba[0])
    if legitimate_probability >= LEGITIMATE_THRESHOLD:
        output_text = "Legitimate"
    else:
        output_text = "Phishing"

    # Serialize the result model and send it back as JSON.
    json_compatible_resp_data = jsonable_encoder(Resp(url=requess.url, label=output_text))
    return JSONResponse(content=json_compatible_resp_data)