#!/usr/bin/env python
# encoding: utf-8
"""FastAPI service that classifies a URL as phishing or legitimate.

The URL is turned into a feature row by ``extract_features``. A pre-trained
model loaded from ``logistic_regression_model.pkl`` then scores that row.
"""
import pickle

from fastapi import FastAPI, Form, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from extraction_features import extract_features

app = FastAPI()

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins before deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the trained model once at import time; the context manager makes sure
# the file handle is closed. pickle can execute arbitrary code while loading,
# so this path must only ever point at a trusted, locally produced file.
with open('logistic_regression_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file, encoding='bytes')


class Msg(BaseModel):
    """Body of ``POST /path``: the text to upper-case."""

    msg: str


class Req(BaseModel):
    """A URL to classify."""

    url: str


class Resp(BaseModel):
    """Prediction result returned by ``POST /predict``."""

    url: str
    # Human-readable verdict: "Legitimate" or "Phishing".
    label: str


@app.get("/")
async def root():
    """Return a welcome message."""
    return {"message": "Hello, Welcome to the final project from Albin Tardivel"}


def form_req(url: str = Form(...)):
    """Build a ``Req`` from the required ``url`` form field (FastAPI dependency)."""
    return Req(url=str(url))


@app.get("/path")
async def demo_get():
    """Describe the ``/path`` endpoint."""
    return {"message": "This is /path endpoint, use a post request to transform the text to uppercase"}


@app.post("/path")
async def demo_post(inp: Msg):
    """Return the posted message upper-cased."""
    return {"message": inp.msg.upper()}


@app.get("/path/{path_id}")
async def demo_get_path_id(path_id: int):
    """Describe the ``/path/{path_id}`` endpoint."""
    return {"message": f"This is /path/{path_id} endpoint, use post request to retrieve result"}


# Renamed from a duplicate ``predict``: it was shadowed at module level by the
# POST handler below. The route path is unchanged.
@app.get("/predict/{path_id}")
async def predict_get_path_id(path_id: int):
    """Describe the ``/predict/{path_id}`` endpoint."""
    return {"message": f"This is /predict/{path_id} endpoint, use post request to retrieve result"}


@app.post("/predict")
async def predict(request: Request, requess: Req = Depends(form_req)):
    """Predict whether the submitted URL is phishing or legitimate.

    Args:
        request: The incoming request (injected by FastAPI; currently unused).
        requess: The URL to classify, parsed from the ``url`` form field.

    Returns:
        A JSON response ``{"url": <url>, "label": "Legitimate" | "Phishing"}``.
    """
    url = requess.url
    features = extract_features(str(url))

    # One feature row, one element per feature. The original used
    # ``list.extend`` here, which raises TypeError for ints/floats and splits
    # strings into characters.
    # NOTE(review): the column order and the raw string columns (URL, Domain,
    # TLD) must match what the model was trained on — presumably the pickle is
    # a pipeline that encodes them; confirm against the training code.
    data = [
        str(features.URL),
        int(features.URLLength),
        str(features.Domain),
        int(features.DomainLength),
        str(features.TLD),
        float(features.CharContinuationRate),
        int(features.TLDLength),
        int(features.NoOfSubDomain),
        float(features.DegitRatioInURL),
        float(features.SpacialCharRatioInURL),
        int(features.IsHTTPS),
    ]

    prediction = model.predict([data])
    output = prediction[0]
    # Class 1 is treated as legitimate; anything else as phishing.
    output_text = "Legitimate" if output == 1 else "Phishing"

    # Return the prediction as JSON.
    json_compatible_resp_data = jsonable_encoder(Resp(url=requess.url, label=output_text))
    return JSONResponse(content=json_compatible_resp_data)