#!/usr/bin/env python
# encoding: utf-8
"""FastAPI service that classifies a URL using a pre-trained model.

The model and its label encoders are loaded from pickle files at import time.
``POST /predict`` extracts features from the submitted ``url`` form field,
encodes the categorical ones and returns the model's prediction as JSON.
The ``/path`` endpoints are small demo routes.
"""
import logging
import pickle

import joblib
import pandas as pd
from fastapi import Depends, FastAPI, Form, Request
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sklearn.preprocessing import LabelEncoder

from extraction_features import extract_features

logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests; consider restricting allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): joblib.load unpickles these files, which can execute arbitrary
# code; only ship artifacts from a trusted source.
model = joblib.load('logistic_regression_model.pkl')
label_encoders = joblib.load('label_encoders.pkl')

# Columns passed to the model, in this order (presumably the training column
# order -- confirm against the training script).
selected_columns = [
    'URL', 'URLLength', 'Domain', 'DomainLength', 'TLD',
    'CharContinuationRate', 'TLDLength', 'NoOfSubDomain',
    'DegitRatioInURL', 'SpacialCharRatioInURL', 'IsHTTPS',
]


def safe_transform(encoder, value):
    """Encode ``value`` with a fitted encoder, tolerating unseen values.

    Args:
        encoder: a fitted encoder exposing ``classes_`` and ``transform``
            (presumably a ``LabelEncoder`` -- confirm against training code).
        value: the raw categorical value to encode.

    Returns:
        The encoder's code for ``value`` if ``value`` is in ``encoder.classes_``,
        otherwise ``-1`` as a special code for values never seen in training.
    """
    if value in encoder.classes_:
        return encoder.transform([value])[0]
    return -1  # Special code for unknown values


class Msg(BaseModel):
    """Request body for ``POST /path``."""

    msg: str


class Req(BaseModel):
    """A URL submitted for classification."""

    url: str


class Resp(BaseModel):
    """Classification result returned by ``POST /predict``."""

    url: str
    label: str


@app.get("/")
async def root():
    """Return a welcome message."""
    return {"message": "Hello, Welcome to the final project from Albin Tardivel"}


def form_req(url: str = Form(...)):
    """Dependency: build a ``Req`` from the required ``url`` form field."""
    return Req(url=str(url))


@app.get("/path")
async def demo_get():
    """Describe the ``/path`` demo endpoint."""
    return {"message": "This is /path endpoint, use a post request to transform the text to uppercase"}


@app.post("/path")
async def demo_post(inp: Msg):
    """Return the submitted message upper-cased."""
    return {"message": inp.msg.upper()}


@app.get("/path/{path_id}")
async def demo_get_path_id(path_id: int):
    """Describe the ``/path/{path_id}`` demo endpoint."""
    return {"message": f"This is /path/{path_id} endpoint, use post request to retrieve result"}


# Named distinctly so it is not shadowed by the POST ``predict`` handler below;
# the route itself is unchanged.
@app.get("/predict/{path_id}")
async def predict_get_path_id(path_id: int):
    """Describe the ``/predict/{path_id}`` endpoint."""
    return {"message": f"This is /predict/{path_id} endpoint, use post request to retrieve result"}


@app.post("/predict")
async def predict(request: Request, requess: Req = Depends(form_req)):
    """Classify the submitted URL and return the result as JSON.

    Args:
        request: the incoming request (not used by the handler).
        requess: the ``Req`` built from the ``url`` form field by ``form_req``.

    Returns:
        A ``JSONResponse`` with ``{"url": <submitted url>, "label": <prediction>}``,
        where ``label`` is the model's first prediction converted to ``str``.
    """
    url = requess.url
    features = extract_features(str(url))
    features_df = pd.DataFrame([features])

    # Replace each categorical column with its integer code; unseen values
    # become -1. The encoder is bound as a default to avoid late binding.
    for column in ('URL', 'Domain', 'TLD'):
        encoder = label_encoders[column]
        features_df[column] = features_df[column].apply(
            lambda value, enc=encoder: safe_transform(enc, value)
        )

    selected = features_df[selected_columns]
    data = selected.to_numpy().reshape(1, -1)

    logger.debug("Nb dimensions before prediction: %s", data.shape)
    logger.debug("Data sent to predict: %s", data)
    logger.debug("Data types: %s", selected.dtypes)

    prediction = model.predict(data)
    # predict() returns a numpy array; its elements may be numpy scalars
    # (e.g. numpy.int64), which pydantic rejects for a ``str`` field.
    output_text = str(prediction[0])

    json_compatible_resp_data = jsonable_encoder(Resp(url=requess.url, label=output_text))
    return JSONResponse(content=json_compatible_resp_data)