# Hugging Face Spaces page banner captured along with the source
# ("Spaces: Sleeping") — not part of the application code.
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import joblib | |
| import pandas as pd | |
| from urllib.parse import urlparse | |
| import re | |
| import numpy as np | |
| import os | |
| from collections import Counter | |
# ----------------------------
# Initialize FastAPI
# ----------------------------
app = FastAPI(title="Malicious URL Detection API")

# ----------------------------
# Load Model & Scaler
# ----------------------------
# Paths are overridable via environment variables so the container can be
# pointed at different artifacts without a code change; the defaults
# preserve the original hard-coded locations.
MODEL_PATH = os.environ.get("MODEL_PATH", "/code/model_repo/random_forest_model (1).joblib")
SCALER_PATH = os.environ.get("SCALER_PATH", "scaler.joblib")
try:
    model = joblib.load(MODEL_PATH)
    scaler = joblib.load(SCALER_PATH)
    print("Model & scaler loaded successfully.")
except Exception as e:
    # Fail fast at startup: the API is useless without both artifacts.
    # Chain the original exception so the root cause stays visible.
    raise RuntimeError(f"Error loading model or scaler: {str(e)}") from e
# ----------------------------
# Request Body Schema
# ----------------------------
class URLRequest(BaseModel):
    """Request body for the prediction endpoint: a single URL to classify."""
    url: str  # raw URL string; length-validated in the predict handler
# ----------------------------
# Helper Functions
# ----------------------------
def count_digits(url):
    """Return the number of decimal-digit characters in *url*."""
    return len([ch for ch in url if ch.isdigit()])
def count_special_chars(url):
    """Return the count of non-word characters (regex ``\\W``) in *url*."""
    return sum(1 for _ in re.finditer(r'\W', url))
def count_letters(url):
    """Return the number of alphabetic characters in *url*."""
    return len(list(filter(str.isalpha, url)))
def has_https(url):
    """Return 1 when *url* starts (case-insensitively) with "https", else 0.

    NOTE(review): this is a plain prefix test, not a scheme check — a
    string like "httpsevil..." would also match. Presumably intentional to
    mirror how the training features were built; confirm before changing.
    """
    return 1 if url.lower().startswith("https") else 0
def count_subdomains(url):
    """Count subdomain labels: hostname dot-parts beyond domain + TLD."""
    host = urlparse(url).hostname
    if not host:
        # No parseable hostname (e.g. schemeless string) -> no subdomains.
        return 0
    # N dots => N+1 labels; subtract the registered domain and the TLD.
    return max(0, host.count(".") - 1)
def extract_tld(url):
    """Return the final dot-separated hostname label (the TLD), or None
    when the URL has no parseable hostname."""
    host = urlparse(url).hostname
    return host.rsplit(".", 1)[-1] if host else None
def path_depth(url):
    """Return the number of '/' separators in the URL's path component."""
    parsed = urlparse(url)
    return parsed.path.count("/")
def shannon_entropy(s):
    """Shannon entropy of *s* in bits per character; 0 for the empty string."""
    total = len(s)
    # Empty string: Counter is empty, sum over nothing is 0 (no div-by-zero).
    return -sum((n / total) * np.log2(n / total) for n in Counter(s).values())
def contains_ip(url):
    """Return 1 if *url* contains a dotted-quad-looking substring, else 0.

    NOTE(review): the pattern does not range-check octets, so e.g.
    999.999.999.999 also matches — likely acceptable for this feature.
    """
    return 1 if re.search(r"(\d{1,3}\.){3}\d{1,3}", url) else 0
def keyword_flag(url, keyword):
    """Return 1 when *keyword* occurs anywhere in *url* (case-insensitive)."""
    return 1 if url.lower().find(keyword) >= 0 else 0
# ----------------------------
# Feature Extraction
# ----------------------------
# Column order must match the order used when the scaler and model were
# trained; predict() reindexes the extracted frame with this exact list.
FEATURE_COLUMNS = [
    "url_length", "num_digits", "num_special", "num_letters",
    "has_https", "num_subdomains", "path_depth", "entropy",
    "contains_ip", "rare_tld_flag",
    "has_login", "has_secure", "has_account", "has_update",
    "has_bank", "has_verify", "has_confirm", "has_payment"
]
# Phishing-associated tokens; each one yields a has_<keyword> binary feature.
KEYWORDS = ["login", "secure", "account", "update", "bank", "verify", "confirm", "payment"]
def extract_features(df):
    """Build the model feature frame from a DataFrame with a 'url' column.

    Returns a new DataFrame containing the original 'url' column followed
    by every column named in FEATURE_COLUMNS, in a fixed order.

    NOTE(review): rare_tld_flag is computed from TLD frequencies over
    *this* batch, not the training distribution — for a single-URL request
    every TLD has frequency 1.0, so the flag is always 0. Confirm against
    how the model was trained before relying on this feature at inference.
    """
    urls = df["url"]
    feats = pd.DataFrame()
    feats["url"] = urls
    # Simple per-URL features, applied in a fixed column order.
    simple_extractors = [
        ("url_length", len),
        ("num_digits", count_digits),
        ("num_special", count_special_chars),
        ("num_letters", count_letters),
        ("has_https", has_https),
        ("num_subdomains", count_subdomains),
        ("path_depth", path_depth),
        ("entropy", shannon_entropy),
        ("contains_ip", contains_ip),
    ]
    for column, extractor in simple_extractors:
        feats[column] = urls.apply(extractor)
    # TLD rarity: flag TLDs seen in fewer than 0.1% of this batch's rows.
    tlds = urls.apply(extract_tld).fillna("unknown")
    tld_freq = tlds.map(tlds.value_counts(normalize=True))
    feats["rare_tld_flag"] = (tld_freq < 0.001).astype(int)
    # One binary indicator column per suspicious keyword.
    for kw in KEYWORDS:
        feats[f"has_{kw}"] = urls.apply(lambda u, kw=kw: keyword_flag(u, kw))
    return feats
# ----------------------------
# Home Route
# ----------------------------
def home():
    """Landing handler that points callers at POST /predict.

    NOTE(review): no @app.get("/") decorator is visible on this function —
    it may have been lost in formatting; without one this handler is never
    routed by FastAPI. Confirm against the original source.
    """
    return dict(message="Malicious URL Detection API is running. Use POST /predict")
# ----------------------------
# Prediction Endpoint
# ----------------------------
def predict(request: URLRequest):
    """Classify a single URL and return {"prediction": <int class label>}.

    Raises HTTP 400 for blank or implausibly short input.

    NOTE(review): no @app.post("/predict") decorator is visible on this
    function — it may have been lost in formatting; confirm it exists in
    the original source, otherwise the endpoint is never registered.
    """
    candidate = request.url.strip()
    # Guard clause: blank or too-short input cannot be a URL.
    if len(candidate) < 3:
        raise HTTPException(status_code=400, detail="Invalid or empty URL provided.")
    # Wrap the single URL in a one-row frame so extract_features applies.
    single = pd.DataFrame({"url": [candidate]})
    # Extract, then reorder to the exact training-time column order.
    feature_frame = extract_features(single)[FEATURE_COLUMNS]
    scaled_features = scaler.transform(feature_frame)
    label = model.predict(scaled_features)[0]
    return {"prediction": int(label)}