Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import joblib | |
| import numpy as np | |
| from urllib.parse import urlparse | |
| from tld import get_tld | |
| import re | |
# FastAPI instance
app = FastAPI()

# Load your trained model.
# This runs once at import time. joblib.load unpickles the file, so
# rf_model.pkl must come from a trusted source; the relative path is resolved
# against the process working directory.
model = joblib.load("rf_model.pkl")  # Ensure you save your RandomForest model as rf_model.pkl
| # Define the request body | |
class URLRequest(BaseModel):
    # Request body schema: a single string field holding the URL to classify.
    url: str  # raw URL text handed to the feature extractors
| # Feature extraction functions | |
def having_ip_address(url):
    """Return 1 if *url* contains an IP-address literal, otherwise 0.

    Three forms are recognised anywhere in the string:

    * dotted-decimal IPv4 followed by ``/``
    * dotted-hexadecimal IPv4 (``0x..`` octets) followed by ``/``
    * a full eight-group IPv6 address

    NOTE(review): the earlier pattern had no ``|`` between the hexadecimal
    and IPv6 alternatives, so a hex IPv4 address only matched when directly
    followed by an IPv6 address, and IPv6 alone never matched. If the model
    was trained on features produced by the old pattern, confirm this fix
    against the training pipeline.
    """
    ip_pattern = (
        r'(([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.'
        r'([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\/)|'  # IPv4
        r'((0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.'
        r'(0x[0-9a-fA-F]{1,2})\/)|'  # IPv4 in hexadecimal
        r'(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}'  # IPv6
    )
    return 1 if re.search(ip_pattern, url) else 0
def abnormal_url(url):
    """Return 1 if the parsed hostname occurs verbatim in *url*, otherwise 0.

    ``urlparse(url).hostname`` is lower-cased by the standard library, so a
    URL whose host is written in upper case yields 0. A URL with no network
    location (for example one without a scheme) has no hostname and also
    yields 0.

    The hostname is compared as plain text. It used to be passed to
    ``re.search`` as a pattern, which raised ``re.error`` for hostnames
    containing regex metacharacters such as ``(`` and searched for the
    literal text ``"None"`` when there was no hostname.

    NOTE(review): results differ from the regex version only in those edge
    cases; confirm against the training pipeline if exact parity matters.
    """
    hostname = urlparse(url).hostname
    if hostname is None:
        return 0
    return 1 if hostname in url else 0
def count_dot(url):
    """Return how many '.' characters *url* contains."""
    dots = url.count('.')
    return dots
def count_www(url):
    """Return the number of non-overlapping 'www' substrings in *url*."""
    occurrences = url.count('www')
    return occurrences
def count_atrate(url):
    """Return how many '@' characters *url* contains."""
    at_signs = url.count('@')
    return at_signs
def no_of_dir(url):
    """Return the number of '/' characters in the path component of *url*."""
    path = urlparse(url).path
    return path.count('/')
def no_of_embed(url):
    """Return the number of non-overlapping '//' sequences in the URL path."""
    path = urlparse(url).path
    return path.count('//')
def shortening_service(url):
    """Return 1 if *url* contains the name of a known URL shortener, else 0.

    The match is a plain substring search over the whole URL, so a short
    entry such as ``t.co`` also fires inside longer host names (for example
    ``microsoft.com``). That behaviour is kept unchanged.

    The pattern is written with raw strings: ``\\.`` inside an ordinary
    string literal is an invalid escape sequence and triggers a
    SyntaxWarning on recent Python versions. The regex itself is unchanged.
    """
    shortener_pattern = (
        r'bit\.ly|goo\.gl|shorte\.st|go2l\.ink|x\.co|ow\.ly|t\.co|tinyurl|tr\.im|is\.gd|cli\.gs|'
        r'yfrog\.com|migre\.me|ff\.im|tiny\.cc|url4\.eu|twit\.ac|su\.pr|twurl\.nl|snipurl\.com|'
        r'short\.to|BudURL\.com|ping\.fm|post\.ly|Just\.as|bkite\.com|snipr\.com|fic\.kr|loopt\.us|'
        r'doiop\.com|short\.ie|kl\.am|wp\.me|rubyurl\.com|om\.ly|to\.ly|bit\.do|t\.co|lnkd\.in|'
        r'db\.tt|qr\.ae|adf\.ly|goo\.gl|bitly\.com|cur\.lv|tinyurl\.com|ow\.ly|bit\.ly|ity\.im|'
        r'q\.gs|is\.gd|po\.st|bc\.vc|twitthis\.com|u\.to|j\.mp|buzurl\.com|cutt\.us|u\.bb|yourls\.org|'
        r'x\.co|prettylinkpro\.com|scrnch\.me|filoops\.info|vzturl\.com|qr\.net|1url\.com|tweez\.me|v\.gd|'
        r'tr\.im|link\.zip\.net'
    )
    return 1 if re.search(shortener_pattern, url) else 0
def count_https(url):
    """Return the number of 'https' substrings in *url*."""
    occurrences = url.count('https')
    return occurrences
def count_http(url):
    """Return the number of 'http' substrings in *url*.

    Every 'https' also contains 'http', so secure-scheme occurrences are
    included in this count.
    """
    occurrences = url.count('http')
    return occurrences
def count_per(url):
    """Return how many '%' characters *url* contains."""
    percent_signs = url.count('%')
    return percent_signs
def count_ques(url):
    """Return how many '?' characters *url* contains."""
    question_marks = url.count('?')
    return question_marks
def count_hyphen(url):
    """Return how many '-' characters *url* contains."""
    hyphens = url.count('-')
    return hyphens
def count_equal(url):
    """Return how many '=' characters *url* contains."""
    equal_signs = url.count('=')
    return equal_signs
def url_length(url):
    """Return the character length of *url* after converting it to ``str``."""
    text = str(url)
    return len(text)
def hostname_length(url):
    """Return the length of the network-location part (``netloc``) of *url*.

    ``netloc`` includes any user-info and port, and is empty when the URL
    has no ``//`` authority section.
    """
    netloc = urlparse(url).netloc
    return len(netloc)
def suspicious_words(url):
    """Return 1 if *url* contains any of the listed keywords, otherwise 0.

    The search is case-sensitive and matches substrings anywhere in the URL.
    """
    found = re.search(
        'PayPal|login|signin|bank|account|update|free|lucky|service|bonus|ebayisapi|webscr',
        url,
    )
    return int(found is not None)
def digit_count(url):
    """Return the number of characters in *url* for which ``str.isnumeric`` is true."""
    numeric_chars = [ch for ch in url if ch.isnumeric()]
    return len(numeric_chars)
def letter_count(url):
    """Return the number of characters in *url* for which ``str.isalpha`` is true."""
    alphabetic_chars = [ch for ch in url if ch.isalpha()]
    return len(alphabetic_chars)
def fd_length(url):
    """Return the length of the first directory segment of the URL path.

    The path is split on '/'; the element at index 1 is the text between the
    first and second slash. When the path contains no '/', there is no such
    element and 0 is returned.
    """
    urlpath = urlparse(url).path
    try:
        return len(urlpath.split('/')[1])
    except IndexError:
        # Path has no '/', so there is no first directory. Only this case is
        # handled; a bare ``except`` would also hide unrelated errors.
        return 0
def tld_length(tld):
    """Return ``len(tld)``, or -1 when *tld* is ``None`` or has no length.

    The caller passes the result of ``get_tld(url, fail_silently=True)``,
    which may be ``None``; -1 marks that case.
    """
    if tld is None:
        return -1
    try:
        return len(tld)
    except TypeError:
        # Object without a length: keep the -1 sentinel instead of a bare
        # ``except`` that would swallow every exception type.
        return -1
| # Extract features from URL | |
def main(url):
    """Build the feature list for *url*.

    Returns a list of 21 values: the results of the twenty URL extractors
    below, in this exact order, followed by the length of the top-level
    domain reported by ``get_tld`` (-1 when it cannot be determined).
    """
    extractors = (
        having_ip_address,
        abnormal_url,
        count_dot,
        count_www,
        count_atrate,
        no_of_dir,
        no_of_embed,
        shortening_service,
        count_https,
        count_http,
        count_per,
        count_ques,
        count_hyphen,
        count_equal,
        url_length,
        hostname_length,
        suspicious_words,
        digit_count,
        letter_count,
        fd_length,
    )
    features = [extract(url) for extract in extractors]
    # fail_silently=True makes get_tld return None instead of raising.
    features.append(tld_length(get_tld(url, fail_silently=True)))
    return features
def get_prediction_from_url(test_url):
    """Classify *test_url* with the loaded model and return its label name.

    The feature list from ``main`` is reshaped into a single-row 2-D array
    and passed to ``model.predict``. The first predicted value is converted
    to ``int`` and mapped to "SAFE", "DEFACEMENT", "PHISHING" or "MALWARE".
    Any other value yields ``None``.
    """
    label_names = {
        0: "SAFE",
        1: "DEFACEMENT",
        2: "PHISHING",
        3: "MALWARE",
    }
    feature_row = np.array(main(test_url)).reshape((1, -1))
    predicted = model.predict(feature_row)
    return label_names.get(int(predicted[0]))
| # Define prediction endpoint | |
# NOTE(review): no route decorator was present, so this handler was never
# registered on ``app`` and the service exposed no endpoint. The "/predict"
# path is an assumption -- confirm against the clients of this API.
@app.post("/predict")
def predict(request: URLRequest):
    """Classify the URL in the request body.

    Returns a JSON object of the form ``{"prediction": <label>}`` where the
    label is whatever ``get_prediction_from_url`` returns for ``request.url``.
    """
    prediction = get_prediction_from_url(request.url)
    return {"prediction": prediction}
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run directly.
    import uvicorn

    # Serve the app on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)