url_prediction / app.py
big2undey's picture
Update app.py
d1525e9 verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import pandas as pd
from urllib.parse import urlparse
import re
import numpy as np
import os
from collections import Counter
# ----------------------------
# Initialize FastAPI
# ----------------------------
app = FastAPI(title="Malicious URL Detection API")
# ----------------------------
# Load Model & Scaler
# ----------------------------
# Both artifacts are loaded once, at import time. Any failure aborts startup
# with a RuntimeError that keeps the original exception as its explicit cause.
# Note: scaler_path is relative, so it resolves against the process working
# directory, while model_path is absolute.
try:
    model_path = "/code/model_repo/random_forest_model (1).joblib"
    scaler_path = "scaler.joblib"
    model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)
    print("Model & scaler loaded successfully.")
except Exception as e:
    raise RuntimeError(f"Error loading model or scaler: {e}") from e
# ----------------------------
# Request Body Schema
# ----------------------------
# Request body for POST /predict: a single field holding the URL to classify.
class URLRequest(BaseModel):
    # Raw URL text exactly as sent by the client.
    url: str
# ----------------------------
# Helper Functions
# ----------------------------
def count_digits(url):
    """Return how many characters of ``url`` satisfy ``str.isdigit``."""
    total = 0
    for ch in url:
        if ch.isdigit():
            total += 1
    return total
def count_special_chars(url):
    """Return the number of non-word characters (regex ``\\W``) in ``url``."""
    non_word_hits = re.finditer(r'\W', url)
    return sum(1 for _ in non_word_hits)
def count_letters(url):
    """Return how many characters of ``url`` satisfy ``str.isalpha``."""
    letters = [ch for ch in url if ch.isalpha()]
    return len(letters)
def has_https(url):
    """Return 1 if the lowercased ``url`` begins with ``"https"``, else 0."""
    lowered = url.lower()
    if lowered.startswith("https"):
        return 1
    return 0
def count_subdomains(url):
    """Return the number of hostname labels beyond the last two.

    Returns 0 when ``urlparse`` finds no hostname (for example when the URL
    has no scheme) or when the hostname has two labels or fewer.
    """
    host = urlparse(url).hostname
    if not host:
        return 0
    labels = host.split(".")
    return max(0, len(labels) - 2)
def extract_tld(url):
    """Return the text after the last dot of the parsed hostname.

    Returns ``None`` when ``urlparse`` finds no hostname. A hostname without
    any dot is returned whole.
    """
    host = urlparse(url).hostname
    if not host:
        return None
    return host.rsplit(".", 1)[-1]
def path_depth(url):
    """Return the number of ``/`` characters in the parsed path of ``url``."""
    path = urlparse(url).path
    return sum(1 for ch in path if ch == "/")
def shannon_entropy(s):
    """Return the Shannon entropy (base 2) of the characters in ``s``.

    An empty ``s`` yields 0 because there are no character counts to sum.
    """
    size = len(s)
    weighted_logs = []
    for occurrences in Counter(s).values():
        share = occurrences / size
        weighted_logs.append(share * np.log2(share))
    return -sum(weighted_logs)
def contains_ip(url):
    """Return 1 if ``url`` contains a dotted quad of 1-3 digit groups, else 0.

    The pattern is searched anywhere in the string and the groups are not
    range-checked.
    """
    found = re.search(r"(\d{1,3}\.){3}\d{1,3}", url)
    return 0 if found is None else 1
def keyword_flag(url, keyword):
    """Return 1 if ``keyword`` occurs in the lowercased ``url``, else 0.

    Only ``url`` is lowercased; ``keyword`` is compared as given.
    """
    haystack = url.lower()
    if keyword in haystack:
        return 1
    return 0
# ----------------------------
# Feature Extraction
# ----------------------------
# Keywords whose presence in a URL is exposed as a has_<keyword> feature.
KEYWORDS = [
    "login",
    "secure",
    "account",
    "update",
    "bank",
    "verify",
    "confirm",
    "payment",
]
# Column order used when selecting features for the scaler and model.
# The has_<keyword> columns follow the order of KEYWORDS.
FEATURE_COLUMNS = [
    "url_length",
    "num_digits",
    "num_special",
    "num_letters",
    "has_https",
    "num_subdomains",
    "path_depth",
    "entropy",
    "contains_ip",
    "rare_tld_flag",
    *(f"has_{keyword}" for keyword in KEYWORDS),
]
def extract_features(df):
    """Build the per-URL feature frame from the ``"url"`` column of ``df``.

    The result keeps the original ``"url"`` column, followed by the lexical
    features, ``rare_tld_flag``, and one ``has_<keyword>`` column for each
    entry of ``KEYWORDS``.

    NOTE(review): the TLD frequency is computed over the rows of ``df``
    itself. For a single-row frame the frequency is 1.0, so
    ``rare_tld_flag`` is always 0 in that case.
    """
    urls = df["url"]
    data = pd.DataFrame()
    data["url"] = urls

    # Features computed independently for each URL, in output column order.
    per_url_features = (
        ("url_length", len),
        ("num_digits", count_digits),
        ("num_special", count_special_chars),
        ("num_letters", count_letters),
        ("has_https", has_https),
        ("num_subdomains", count_subdomains),
        ("path_depth", path_depth),
        ("entropy", shannon_entropy),
        ("contains_ip", contains_ip),
    )
    for column, func in per_url_features:
        data[column] = urls.apply(func)

    # Rare TLD flag: a missing TLD is counted as "unknown"; a TLD is flagged
    # when its share of the rows is below 0.1%.
    tlds = urls.apply(extract_tld).fillna("unknown")
    tld_share = tlds.map(tlds.value_counts(normalize=True))
    data["rare_tld_flag"] = (tld_share < 0.001).astype(int)

    # Keyword flags
    for key in KEYWORDS:
        data[f"has_{key}"] = urls.apply(lambda u, k=key: keyword_flag(u, k))
    return data
# ----------------------------
# Home Route
# ----------------------------
@app.get("/")
def home():
    # Status message pointing clients at the prediction endpoint.
    status = {"message": "Malicious URL Detection API is running. Use POST /predict"}
    return status
# ----------------------------
# Prediction Endpoint
# ----------------------------
@app.post("/predict")
def predict(request: URLRequest):
    # Classify one URL: strip it, extract features, select FEATURE_COLUMNS,
    # scale with `scaler`, and return the label from `model.predict` as
    # {"prediction": <int>}.
    # Raises HTTPException(400) when the stripped URL is shorter than three
    # characters or cannot be parsed.
    # (Comments, not a docstring: FastAPI publishes route docstrings as the
    # OpenAPI operation description.)
    url = request.url.strip()
    # An empty string also fails this length check.
    if len(url) < 3:
        raise HTTPException(status_code=400, detail="Invalid or empty URL provided.")

    frame = pd.DataFrame({"url": [url]})
    try:
        features = extract_features(frame)
    except ValueError as exc:
        # urlparse raises ValueError for malformed hosts (e.g. an unclosed
        # IPv6 bracket); report that as a client error instead of a 500.
        raise HTTPException(
            status_code=400, detail="Invalid or empty URL provided."
        ) from exc

    # Reorder features to match training order
    features = features[FEATURE_COLUMNS]
    # Scale features
    scaled = scaler.transform(features)
    # Predict
    pred = model.predict(scaled)[0]
    return {"prediction": int(pred)}