Spaces:
Runtime error
Runtime error
File size: 4,666 Bytes
5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 80f46f9 5d7f8a1 4009697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import numpy as np
from urllib.parse import urlparse
from tld import get_tld
import re
# FastAPI application instance serving the /predict/ endpoint below.
app = FastAPI()
# Load the trained classifier once at import time so every request reuses it.
# NOTE(review): assumes rf_model.pkl is a scikit-learn RandomForest trained on
# the 21-feature vector produced by main() below — verify against the training
# pipeline.
model = joblib.load("rf_model.pkl") # Ensure you save your RandomForest model as rf_model.pkl
# Define the request body
class URLRequest(BaseModel):
    """Request body for /predict/: a single URL string to classify."""
    url: str  # raw URL text; feature extraction parses it with urlparse
# Feature extraction functions
def having_ip_address(url):
    """Return 1 if the URL contains an IP address literal, else 0.

    Recognizes dotted-decimal IPv4 (followed by '/'), hexadecimal IPv4
    (followed by '/'), and plain IPv6 addresses.

    Fix: the original pattern was missing the '|' between the hex-IPv4
    alternative and the IPv6 alternative, so a bare IPv6 address could
    never match (it was only matched when glued directly onto a hex IPv4).
    Raw strings are used so the '\\d' / '\\.' escapes are literal regex
    escapes rather than string-level ones.
    """
    match = re.search(
        r'(([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.'
        r'([01]?\d\d?|2[0-4]\d|25[0-5])\/)|'  # IPv4 dotted decimal
        r'((0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\/)|'  # IPv4 in hex
        r'(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}', url)  # IPv6
    return 1 if match else 0
def abnormal_url(url):
    """Return 1 if the URL's own hostname text occurs in the raw URL, else 0.

    Mirrors the training-time feature (for a well-formed URL the hostname is
    a substring of the URL, so this is normally 1).

    Fixes two defects in the original:
    - the hostname was passed to re.search unescaped, so a hostname containing
      regex metacharacters could raise re.error;
    - a URL with no hostname produced the literal search pattern "None".
    """
    hostname = urlparse(url).hostname
    if hostname is None:
        # No parseable hostname: treat as "hostname not present in URL".
        return 0
    return 1 if re.search(re.escape(hostname), url) else 0
def count_dot(url):
    """Number of '.' characters in the whole URL."""
    return url.count(".")

def count_www(url):
    """Number of occurrences of the substring 'www' in the URL."""
    return url.count("www")

def count_atrate(url):
    """Number of '@' characters in the URL."""
    return url.count("@")

def no_of_dir(url):
    """Number of '/' characters in the URL path (directory depth)."""
    path = urlparse(url).path
    return path.count("/")

def no_of_embed(url):
    """Number of embedded '//' sequences in the URL path."""
    path = urlparse(url).path
    return path.count("//")
def shortening_service(url):
    """Return 1 if the URL references a known link-shortening service, else 0.

    Matching is case-sensitive substring search, preserved exactly from the
    training pipeline (including duplicate alternatives such as bit.ly).

    Fix: the pattern fragments are now raw strings; the original used '\\.'
    inside plain string literals, which is an invalid string escape
    (SyntaxWarning on modern CPython) even though it happened to work.
    """
    match = re.search(
        r'bit\.ly|goo\.gl|shorte\.st|go2l\.ink|x\.co|ow\.ly|t\.co|tinyurl|tr\.im|is\.gd|cli\.gs|'
        r'yfrog\.com|migre\.me|ff\.im|tiny\.cc|url4\.eu|twit\.ac|su\.pr|twurl\.nl|snipurl\.com|'
        r'short\.to|BudURL\.com|ping\.fm|post\.ly|Just\.as|bkite\.com|snipr\.com|fic\.kr|loopt\.us|'
        r'doiop\.com|short\.ie|kl\.am|wp\.me|rubyurl\.com|om\.ly|to\.ly|bit\.do|t\.co|lnkd\.in|'
        r'db\.tt|qr\.ae|adf\.ly|goo\.gl|bitly\.com|cur\.lv|tinyurl\.com|ow\.ly|bit\.ly|ity\.im|'
        r'q\.gs|is\.gd|po\.st|bc\.vc|twitthis\.com|u\.to|j\.mp|buzurl\.com|cutt\.us|u\.bb|yourls\.org|'
        r'x\.co|prettylinkpro\.com|scrnch\.me|filoops\.info|vzturl\.com|qr\.net|1url\.com|tweez\.me|v\.gd|'
        r'tr\.im|link\.zip\.net', url)
    return 1 if match else 0
def count_https(url):
    """Occurrences of the substring 'https' anywhere in the URL."""
    return url.count("https")

def count_http(url):
    """Occurrences of the substring 'http' anywhere in the URL.

    Note: this also counts the 'http' inside every 'https', matching the
    training-time feature.
    """
    return url.count("http")

def count_per(url):
    """Number of '%' characters (percent-encoding markers) in the URL."""
    return url.count("%")

def count_ques(url):
    """Number of '?' characters in the URL."""
    return url.count("?")

def count_hyphen(url):
    """Number of '-' characters in the URL."""
    return url.count("-")

def count_equal(url):
    """Number of '=' characters in the URL."""
    return url.count("=")
def url_length(url):
    """Total character length of the URL (coerced to str first)."""
    text = str(url)
    return len(text)
def hostname_length(url):
    """Length of the network-location (netloc) component of the URL."""
    netloc = urlparse(url).netloc
    return len(netloc)
def suspicious_words(url):
    """Return 1 if the URL contains a phishing-associated keyword, else 0.

    Matching is case-sensitive ('LOGIN' does not match), preserved from the
    training pipeline.
    """
    pattern = 'PayPal|login|signin|bank|account|update|free|lucky|service|bonus|ebayisapi|webscr'
    return 1 if re.search(pattern, url) else 0
def digit_count(url):
    """Count of numeric characters in the URL (str.isnumeric per char)."""
    total = 0
    for ch in url:
        if ch.isnumeric():
            total += 1
    return total

def letter_count(url):
    """Count of alphabetic characters in the URL (str.isalpha per char)."""
    return sum(ch.isalpha() for ch in url)
def fd_length(url):
    """Length of the first directory segment of the URL path, or 0.

    e.g. 'http://h/abc/def' -> len('abc') == 3.  An empty path has no
    '/'-delimited first segment, so 0 is returned.

    Fix: the original used a bare `except:` that swallowed every exception;
    only IndexError (path has no second split element) can legitimately
    occur here, so catch exactly that.
    """
    urlpath = urlparse(url).path
    try:
        return len(urlpath.split('/')[1])
    except IndexError:
        return 0
def tld_length(tld):
    """Length of the extracted TLD string; -1 when extraction failed.

    The caller passes the result of get_tld(..., fail_silently=True), which
    is a string or None.

    Fix: the original wrapped len() in a bare `except:` to map None to -1;
    an explicit None check states the intent without masking other errors.
    """
    return -1 if tld is None else len(tld)
# Extract features from URL
def main(url):
status = []
status.append(having_ip_address(url))
status.append(abnormal_url(url))
status.append(count_dot(url))
status.append(count_www(url))
status.append(count_atrate(url))
status.append(no_of_dir(url))
status.append(no_of_embed(url))
status.append(shortening_service(url))
status.append(count_https(url))
status.append(count_http(url))
status.append(count_per(url))
status.append(count_ques(url))
status.append(count_hyphen(url))
status.append(count_equal(url))
status.append(url_length(url))
status.append(hostname_length(url))
status.append(suspicious_words(url))
status.append(digit_count(url))
status.append(letter_count(url))
status.append(fd_length(url))
tld = get_tld(url, fail_silently=True)
status.append(tld_length(tld))
return status
def get_prediction_from_url(test_url):
    """Classify one URL with the loaded model and return its label.

    Returns "SAFE", "DEFACEMENT", "PHISHING", or "MALWARE"; any class index
    outside 0-3 yields None (same fall-through as the original if/elif chain).
    """
    features = np.array(main(test_url)).reshape((1, -1))
    class_index = int(model.predict(features)[0])
    labels = {
        0: "SAFE",
        1: "DEFACEMENT",
        2: "PHISHING",
        3: "MALWARE",
    }
    return labels.get(class_index)
# Define prediction endpoint
@app.post("/predict/")
def predict(request: URLRequest):
    """POST endpoint: classify the submitted URL and return the label as JSON."""
    return {"prediction": get_prediction_from_url(request.url)}
if __name__ == "__main__":
    import uvicorn

    # Serve the API on all interfaces, port 8000, when run as a script.
    # Fix: the original line ended with a stray ' |' (page-scrape artifact),
    # which is a syntax error.
    uvicorn.run(app, host="0.0.0.0", port=8000)