import re
from urllib.parse import urlparse

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from tld import get_tld

# FastAPI instance
app = FastAPI()

# Load your trained model once, at import time. The path is resolved relative
# to the current working directory. joblib.load unpickles the file, so it must
# only ever point at a trusted artifact.
model = joblib.load("rf_model.pkl")  # Ensure you save your RandomForest model as rf_model.pkl

# Integer class label produced by `model.predict` -> API response string.
# NOTE(review): this must match the label encoding used at training time — confirm.
_PREDICTION_LABELS = {
    0: "SAFE",
    1: "DEFACEMENT",
    2: "PHISHING",
    3: "MALWARE",
}

# Patterns are raw strings (non-raw '\.' is an invalid escape sequence and warns
# on Python 3.12+) and are compiled once instead of on every request.
#
# NOTE(review): there is no '|' between the hex-IPv4 alternative and the IPv6
# alternative, so the IPv6 part is concatenated onto the hex-IPv4 part and a
# bare IPv6 address never matches on its own. It is kept as-is so the feature
# is computed exactly as it presumably was during training; fixing it requires
# retraining the model with the corrected feature.
_IP_ADDRESS_RE = re.compile(
    r'(([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.'
    r'([01]?\d\d?|2[0-4]\d|25[0-5])\/)|'  # IPv4
    r'((0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\/)'  # IPv4 in hexadecimal
    r'(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}'  # Ipv6 (see note above)
)

# NOTE(review): the pattern is unanchored, so short entries such as 't\.co'
# also match inside ordinary hostnames (e.g. 'microsoft.com'). Kept unchanged
# for parity with the trained model.
_SHORTENING_SERVICE_RE = re.compile(
    r'bit\.ly|goo\.gl|shorte\.st|go2l\.ink|x\.co|ow\.ly|t\.co|tinyurl|tr\.im|is\.gd|cli\.gs|'
    r'yfrog\.com|migre\.me|ff\.im|tiny\.cc|url4\.eu|twit\.ac|su\.pr|twurl\.nl|snipurl\.com|'
    r'short\.to|BudURL\.com|ping\.fm|post\.ly|Just\.as|bkite\.com|snipr\.com|fic\.kr|loopt\.us|'
    r'doiop\.com|short\.ie|kl\.am|wp\.me|rubyurl\.com|om\.ly|to\.ly|bit\.do|t\.co|lnkd\.in|'
    r'db\.tt|qr\.ae|adf\.ly|goo\.gl|bitly\.com|cur\.lv|tinyurl\.com|ow\.ly|bit\.ly|ity\.im|'
    r'q\.gs|is\.gd|po\.st|bc\.vc|twitthis\.com|u\.to|j\.mp|buzurl\.com|cutt\.us|u\.bb|yourls\.org|'
    r'x\.co|prettylinkpro\.com|scrnch\.me|filoops\.info|vzturl\.com|qr\.net|1url\.com|tweez\.me|v\.gd|'
    r'tr\.im|link\.zip\.net'
)

# Case-sensitive on purpose (e.g. 'PayPal' does not match 'paypal'), as in the
# original feature definition.
_SUSPICIOUS_WORDS_RE = re.compile(
    r'PayPal|login|signin|bank|account|update|free|lucky|service|bonus|ebayisapi|webscr'
)


# Define the request body
class URLRequest(BaseModel):
    url: str


# Feature extraction functions
def having_ip_address(url: str) -> int:
    """Return 1 if *url* matches the IP-address pattern, else 0."""
    return 1 if _IP_ADDRESS_RE.search(url) else 0


def abnormal_url(url: str) -> int:
    """Return 1 if the parsed hostname occurs literally in *url*, else 0.

    When urlparse finds no hostname, the string "None" is searched for, as in
    the original implementation.
    """
    hostname = urlparse(url).hostname
    # Literal containment instead of re.search(hostname, url): the hostname
    # comes from untrusted input, and characters such as '(' or '[' made the
    # regex compile fail (HTTP 500), while '.' acted as a wildcard.
    return 1 if str(hostname) in url else 0


def count_dot(url: str) -> int:
    """Return the number of '.' characters in *url*."""
    return url.count('.')


def count_www(url: str) -> int:
    """Return the number of non-overlapping 'www' substrings in *url*."""
    return url.count('www')


def count_atrate(url: str) -> int:
    """Return the number of '@' characters in *url*."""
    return url.count('@')


def no_of_dir(url: str) -> int:
    """Return the number of '/' characters in the path component of *url*."""
    return urlparse(url).path.count('/')


def no_of_embed(url: str) -> int:
    """Return the number of non-overlapping '//' sequences in the path of *url*."""
    return urlparse(url).path.count('//')


def shortening_service(url: str) -> int:
    """Return 1 if *url* mentions a known URL-shortening service, else 0."""
    return 1 if _SHORTENING_SERVICE_RE.search(url) else 0


def count_https(url: str) -> int:
    """Return the number of 'https' substrings in *url*."""
    return url.count('https')


def count_http(url: str) -> int:
    """Return the number of 'http' substrings in *url*.

    Every 'https' occurrence is counted here as well; kept for parity with the
    trained model.
    """
    return url.count('http')


def count_per(url: str) -> int:
    """Return the number of '%' characters in *url*."""
    return url.count('%')


def count_ques(url: str) -> int:
    """Return the number of '?' characters in *url*."""
    return url.count('?')


def count_hyphen(url: str) -> int:
    """Return the number of '-' characters in *url*."""
    return url.count('-')


def count_equal(url: str) -> int:
    """Return the number of '=' characters in *url*."""
    return url.count('=')


def url_length(url: str) -> int:
    """Return the length of *url* in characters."""
    return len(str(url))


def hostname_length(url: str) -> int:
    """Return the length of the netloc of *url* (includes userinfo and port)."""
    return len(urlparse(url).netloc)


def suspicious_words(url: str) -> int:
    """Return 1 if *url* contains one of the suspicious keywords, else 0."""
    return 1 if _SUSPICIOUS_WORDS_RE.search(url) else 0


def digit_count(url: str) -> int:
    """Return the number of characters in *url* for which str.isnumeric() is true."""
    return sum(1 for ch in url if ch.isnumeric())


def letter_count(url: str) -> int:
    """Return the number of characters in *url* for which str.isalpha() is true."""
    return sum(1 for ch in url if ch.isalpha())


def fd_length(url: str) -> int:
    """Return the length of the first directory segment of the path, or 0 if none.

    For a path like '/abc/def' this is len('abc'); a path with no '/' (for
    example the empty path of 'http://example.com') yields 0.
    """
    segments = urlparse(url).path.split('/')
    return len(segments[1]) if len(segments) > 1 else 0


def tld_length(tld) -> int:
    """Return len(tld), or -1 when *tld* has no length (e.g. get_tld returned None)."""
    try:
        return len(tld)
    except TypeError:
        return -1


# Extract features from URL
def main(url: str) -> list:
    """Return the 21-element feature vector for *url*.

    The element order is the column order the model consumes and must not be
    changed without retraining.
    """
    # A list display evaluates left to right, so extraction order is unchanged.
    return [
        having_ip_address(url),
        abnormal_url(url),
        count_dot(url),
        count_www(url),
        count_atrate(url),
        no_of_dir(url),
        no_of_embed(url),
        shortening_service(url),
        count_https(url),
        count_http(url),
        count_per(url),
        count_ques(url),
        count_hyphen(url),
        count_equal(url),
        url_length(url),
        hostname_length(url),
        suspicious_words(url),
        digit_count(url),
        letter_count(url),
        fd_length(url),
        tld_length(get_tld(url, fail_silently=True)),
    ]


def get_prediction_from_url(test_url: str):
    """Classify *test_url* and return its label string.

    Returns one of "SAFE", "DEFACEMENT", "PHISHING", "MALWARE", or None if the
    model emits a class index not present in _PREDICTION_LABELS.
    """
    features = np.array(main(test_url)).reshape((1, -1))
    label = int(model.predict(features)[0])
    return _PREDICTION_LABELS.get(label)


# Define prediction endpoint
@app.post("/predict/")
def predict(request: URLRequest):
    """Return {"prediction": <label>} for the submitted URL.

    Responds with HTTP 422 when the URL cannot be parsed by urllib.parse.
    """
    try:
        # urlparse raises ValueError on malformed input such as an unterminated
        # IPv6 literal ("http://[::1"). Every feature re-parses the same
        # string, so validating once here covers them all.
        urlparse(request.url)
    except ValueError as err:
        raise HTTPException(status_code=422, detail=f"Invalid URL: {err}") from err
    prediction = get_prediction_from_url(request.url)
    return {"prediction": prediction}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)