Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # encoding: utf-8 | |
| from fastapi import FastAPI, Form, Depends, Request | |
| from fastapi.encoders import jsonable_encoder | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import pickle | |
| import joblib | |
| import pandas as pd | |
| from sklearn.preprocessing import LabelEncoder | |
| from extraction_features import extract_features | |
app = FastAPI()

# Let browsers on any origin call this API with any HTTP method and header.
# NOTE(review): FastAPI's CORS documentation states that allow_origins cannot be
# ["*"] when credentials are to be allowed (origins must be listed explicitly).
# Nothing visible in this file uses cookies or auth - confirm whether
# allow_credentials=True is actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Load the trained classifier and the fitted label encoders once, at import time.
# label_encoders is presumably a dict of column name -> fitted LabelEncoder
# (it is indexed by 'Domain' / 'TLD' when predicting) - confirm against training code.
# NOTE: joblib.load unpickles its input, so these files must come from a trusted
# source. Both paths are resolved relative to the current working directory.
model = joblib.load('logistic_regression_model.pkl')
label_encoders = joblib.load('label_encoders.pkl')
# Feature columns passed to the model. Prediction converts the frame with
# `.values`, which drops column names, so this order must match the order the
# model was trained with (presumably the training-time order - confirm).
# The names - including the spellings "Degit" and "Spacial" - must match the
# keys produced by extract_features, because columns are selected by name.
selected_columns = [
    'URLLength',
    'Domain',
    'DomainLength',
    'TLD',
    'CharContinuationRate',
    'TLDLength',
    'NoOfSubDomain',
    'DegitRatioInURL',
    'SpacialCharRatioInURL',
    'IsHTTPS',
]
def safe_transform(encoder, value):
    """Encode ``value`` with a fitted label encoder, tolerating unknown values.

    scikit-learn's ``LabelEncoder.transform`` raises for labels outside
    ``encoder.classes_``. Instead of raising, this helper maps any value the
    encoder never saw (e.g. an unfamiliar domain or TLD) to the fixed
    sentinel ``-1``.

    Returns the encoded label for a known value, otherwise ``-1``.
    """
    known_labels = encoder.classes_
    if value not in known_labels:
        return -1
    return encoder.transform([value])[0]
# JSON body for the demo POST endpoint: a single text field.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class Msg(BaseModel):
    msg: str  # text to be upper-cased
# Prediction request: the URL to classify (built from form data by form_req).
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class Req(BaseModel):
    url: str  # raw URL as submitted
# Prediction response returned as JSON.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class Resp(BaseModel):
    url: str    # the URL that was classified
    label: str  # "Legitimate" or "Phishing"
# Registered on the app; without a route decorator FastAPI never serves this handler.
@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Hello, Welcome to the final project from Albin Tardivel"}
def form_req(url: str = Form(...)):
    """FastAPI dependency: wrap the required ``url`` form field in a ``Req`` model."""
    url_text = str(url)
    return Req(url=url_text)
# Registered on the app; the path comes from the handler's own usage message.
@app.get("/path")
async def demo_get():
    """Explain how to use the ``/path`` endpoint."""
    return {"message": "This is /path endpoint, use a post request to transform the text to uppercase"}
# Registered on the app as the POST counterpart of GET /path.
@app.post("/path")
async def demo_post(inp: Msg):
    """Return the posted ``msg`` text converted to upper case."""
    return {"message": inp.msg.upper()}
# Registered on the app; the path comes from the handler's own usage message.
@app.get("/path/{path_id}")
async def demo_get_path_id(path_id: int):
    """Echo ``path_id`` (validated as an int by FastAPI) in a usage hint."""
    return {"message": f"This is /path/{path_id} endpoint, use post request to retrieve result"}
# NOTE(review): the POST prediction handler defined right after this one rebinds
# the module-level name ``predict``. This route stays registered because the
# decorator runs at definition time, but a distinct function name would avoid
# the shadowing.
@app.get("/predict/{path_id}")
async def predict(path_id: int):
    """Usage hint for ``/predict/{path_id}``; predictions themselves are served via POST."""
    return {"message": f"This is /predict/{path_id} endpoint, use post request to retrieve result"}
# Registered on the app; without a route decorator FastAPI never serves this handler.
@app.post("/predict")
async def predict(request: Request, requess: Req = Depends(form_req)):
    """Classify a submitted URL as phishing or legitimate.

    The URL arrives as the ``url`` form field (parsed by ``form_req``). Its
    features are extracted, the ``Domain`` and ``TLD`` columns are
    label-encoded (values the encoders never saw become ``-1``), and the
    model's class-1 probability is compared with a 0.9 threshold: at or above
    it the URL is labelled "Legitimate", otherwise "Phishing".

    Returns a JSON object ``{"url": <submitted url>, "label": <label>}``.
    """
    # ``request`` is not used; it is kept in the signature for compatibility.
    # NOTE(review): feature extraction and inference are synchronous calls made
    # inside an ``async def``, so they block the event loop while they run.
    url = requess.url
    features = extract_features(url)
    features_df = pd.DataFrame([features])

    # Replace the categorical Domain/TLD values with integer codes
    # (presumably the codes used when the model was trained).
    for column in ('Domain', 'TLD'):
        encoder = label_encoders[column]
        features_df[column] = features_df[column].apply(lambda x: safe_transform(encoder, x))

    # ``.values`` drops column names, so the model sees the columns in
    # ``selected_columns`` order.
    data = features_df[selected_columns].values.reshape(1, -1)

    # predict_proba yields one row per sample; keep the class-1 column.
    # Assumes model.classes_ is [0, 1] with 1 meaning "legitimate" - TODO confirm.
    prediction_proba = model.predict_proba(data)[:, 1]
    print("prediction_proba:", prediction_proba)
    # Compare an explicit scalar instead of relying on a 1-element array's truthiness.
    legit_probability = float(prediction_proba[0])

    # A URL is called legitimate only when the model is at least 90% sure;
    # everything below that is reported as phishing.
    threshold = 0.9
    output_text = "Legitimate" if legit_probability >= threshold else "Phishing"

    # Respond with JSON (no HTML page is rendered here).
    json_compatible_resp_data = jsonable_encoder(Resp(url=url, label=output_text))
    return JSONResponse(content=json_compatible_resp_data)