# V2LinkDetection / app.py
# Author: ZealPyae (last commit "Update app.py", 4009697, verified)
import re
from urllib.parse import urlparse

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from tld import get_tld
# ASGI application object; the /predict/ route is registered on it and
# uvicorn serves it when this file is executed as a script.
app = FastAPI()
# Load the trained classifier once at import time so every request reuses it.
# The relative path is resolved against the process's current working directory.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only ever point this at a trusted, locally produced artifact.
model = joblib.load("rf_model.pkl")  # Ensure you save your RandomForest model as rf_model.pkl
# Request body schema for POST /predict/
class URLRequest(BaseModel):
    """JSON body of a prediction request: ``{"url": "<url to classify>"}``.

    The value is only validated as a string; it is passed to the feature
    extractors as-is (no scheme or format is enforced here).
    """
    url: str  # raw URL text to classify
# Feature extraction functions
def having_ip_address(url):
    """Return 1 if *url* appears to contain an IP-address host, otherwise 0.

    NOTE(review): the hexadecimal-IPv4 and IPv6 fragments are concatenated
    without a "|" between them, so the pattern really has only two
    alternatives:
      * a dotted-decimal IPv4 address immediately followed by "/";
      * a hex IPv4 address plus "/" immediately followed by an IPv6 address.
    A bare IPv6 or bare hex-IPv4 host therefore does NOT match. This is kept
    deliberately: the feature presumably has to be computed exactly as it was
    when rf_model.pkl was trained -- confirm against the training code before
    "fixing" it.
    """
    pattern = (
        r'(([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.'
        r'([01]?\d\d?|2[0-4]\d|25[0-5])\/)|'  # IPv4
        r'((0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\/)'  # IPv4 in hexadecimal
        r'(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}'  # IPv6 (no "|" before it -- see NOTE)
    )
    found = re.search(pattern, url)
    return int(found is not None)
def abnormal_url(url):
    """Return 1 if the parsed hostname of *url* is found in *url*, else 0.

    The hostname is used as a regular expression, not a literal, as in the
    original feature definition. As a result:
      * "." in the hostname matches any character;
      * matching is case-sensitive, while ``urlparse`` lower-cases the
        hostname, so "http://Example.com" scores 0;
      * a URL with no hostname (e.g. no scheme) searches for the text "None".

    A hostname containing regex metacharacters such as "(" is not a valid
    pattern and previously raised ``re.error`` (an HTTP 500 from the API).
    In that case, fall back to a literal substring test, which is the
    evident intent of the check.
    """
    pattern = str(urlparse(url).hostname)
    try:
        match = re.search(pattern, url)
    except re.error:
        match = pattern in url
    return 1 if match else 0
def count_dot(url):
return url.count('.')
def count_www(url):
    """Number of non-overlapping occurrences of "www" in *url*."""
    pieces = url.split('www')
    return len(pieces) - 1
def count_atrate(url):
    """Number of "@" characters in *url*."""
    return sum(ch == '@' for ch in url)
def no_of_dir(url):
    """Number of "/" characters in the path component of *url*."""
    path = urlparse(url).path
    return sum(1 for ch in path if ch == '/')
def no_of_embed(url):
    """Number of non-overlapping "//" sequences in the path of *url*."""
    path = urlparse(url).path
    return len(path.split('//')) - 1
def shortening_service(url):
    """Return 1 if *url* mentions a known URL-shortening service, else 0.

    The pattern pieces are raw strings so that backslash-dot is passed to the
    regex engine as an escaped dot. The previous plain strings relied on
    invalid Python escape sequences, which emit a SyntaxWarning on Python
    3.12+. The compiled pattern is unchanged.

    NOTE(review): this is an unanchored, case-sensitive substring search, so
    short entries such as "t.co" and "x.co" also match ordinary hosts like
    "microsoft.com" or "box.com". Some entries are also listed twice. This is
    kept as-is, presumably to match how the training features were computed
    -- confirm before changing.
    """
    pattern = (
        r'bit\.ly|goo\.gl|shorte\.st|go2l\.ink|x\.co|ow\.ly|t\.co|tinyurl|tr\.im|is\.gd|cli\.gs|'
        r'yfrog\.com|migre\.me|ff\.im|tiny\.cc|url4\.eu|twit\.ac|su\.pr|twurl\.nl|snipurl\.com|'
        r'short\.to|BudURL\.com|ping\.fm|post\.ly|Just\.as|bkite\.com|snipr\.com|fic\.kr|loopt\.us|'
        r'doiop\.com|short\.ie|kl\.am|wp\.me|rubyurl\.com|om\.ly|to\.ly|bit\.do|t\.co|lnkd\.in|'
        r'db\.tt|qr\.ae|adf\.ly|goo\.gl|bitly\.com|cur\.lv|tinyurl\.com|ow\.ly|bit\.ly|ity\.im|'
        r'q\.gs|is\.gd|po\.st|bc\.vc|twitthis\.com|u\.to|j\.mp|buzurl\.com|cutt\.us|u\.bb|yourls\.org|'
        r'x\.co|prettylinkpro\.com|scrnch\.me|filoops\.info|vzturl\.com|qr\.net|1url\.com|tweez\.me|v\.gd|'
        r'tr\.im|link\.zip\.net'
    )
    match = re.search(pattern, url)
    return 1 if match else 0
def count_https(url):
    """Number of non-overlapping occurrences of "https" in *url*."""
    chunks = url.split('https')
    return len(chunks) - 1
def count_http(url):
    """Number of non-overlapping occurrences of "http" in *url*.

    Every "https" also contains "http", so secure links are counted here too.
    """
    chunks = url.split('http')
    return len(chunks) - 1
def count_per(url):
    """Number of "%" characters (percent-encoding markers) in *url*."""
    return sum(ch == '%' for ch in url)
def count_ques(url):
    """Number of "?" characters in *url*."""
    return sum(1 for ch in url if ch == '?')
def count_hyphen(url):
    """Number of "-" characters in *url*."""
    parts = url.split('-')
    return len(parts) - 1
def count_equal(url):
    """Number of "=" characters in *url*."""
    return len([ch for ch in url if ch == '='])
def url_length(url):
    """Number of characters in *url* after converting it to ``str``."""
    as_text = str(url)
    return len(as_text)
def hostname_length(url):
    """Length of the network-location component of *url*.

    Despite the name, this measures ``urlparse(url).netloc``. That includes
    any ``user@`` prefix and ``:port`` suffix, not just the bare hostname.
    """
    netloc = urlparse(url).netloc
    return len(netloc)
def suspicious_words(url):
    """Return 1 if *url* contains one of a fixed set of phishing keywords.

    The search is case-sensitive: "PayPal" must appear with that exact
    capitalisation, and upper-case variants such as "LOGIN" are not detected.
    """
    keywords = 'PayPal|login|signin|bank|account|update|free|lucky|service|bonus|ebayisapi|webscr'
    return 0 if re.search(keywords, url) is None else 1
def digit_count(url):
    """Number of characters in *url* for which ``str.isnumeric`` is true."""
    return sum(map(str.isnumeric, url))
def letter_count(url):
    """Number of characters in *url* for which ``str.isalpha`` is true."""
    return len([ch for ch in url if ch.isalpha()])
def fd_length(url):
    """Length of the first directory segment of the URL path.

    For example, path "/abc/def" gives 3. Returns 0 when the path contains
    no "/", e.g. an empty path, or a scheme-less URL whose whole text is
    parsed as the path. The previous bare ``except:`` (which would also
    swallow ``KeyboardInterrupt``/``SystemExit``) is replaced by an explicit
    length check with identical results.
    """
    segments = urlparse(url).path.split('/')
    # segments[0] is the text before the first "/" (normally "").
    if len(segments) < 2:
        return 0
    return len(segments[1])
def tld_length(tld):
    """Return ``len(tld)``, or -1 when *tld* has no length (e.g. ``None``).

    Only ``TypeError`` (the error ``len`` raises for unsized objects such as
    ``None``) is handled. The previous bare ``except:`` also hid unrelated
    errors and interrupts.
    """
    try:
        return len(tld)
    except TypeError:
        # The caller passes get_tld(..., fail_silently=True), which may be None.
        return -1
# Extract features from URL
def main(url):
    """Return the 21-value feature vector for *url*, in model input order.

    The list is fed positionally to ``model.predict`` by
    ``get_prediction_from_url``, so its order presumably has to match the
    column order used to train rf_model.pkl. Do not reorder, insert or
    remove entries.
    """
    ordered_extractors = (
        having_ip_address,
        abnormal_url,
        count_dot,
        count_www,
        count_atrate,
        no_of_dir,
        no_of_embed,
        shortening_service,
        count_https,
        count_http,
        count_per,
        count_ques,
        count_hyphen,
        count_equal,
        url_length,
        hostname_length,
        suspicious_words,
        digit_count,
        letter_count,
        fd_length,
    )
    features = [extract(url) for extract in ordered_extractors]
    # Last feature: TLD length. With fail_silently=True get_tld is expected
    # to return None on failure, which tld_length maps to -1.
    features.append(tld_length(get_tld(url, fail_silently=True)))
    return features
def get_prediction_from_url(test_url):
    """Classify *test_url* with the loaded model and return its label.

    Returns "SAFE", "DEFACEMENT", "PHISHING" or "MALWARE" for class
    indices 0-3. Any other index falls through to ``None``.
    """
    labels = {0: "SAFE", 1: "DEFACEMENT", 2: "PHISHING", 3: "MALWARE"}
    # The model scores one sample at a time: a single row of features.
    single_row = np.array(main(test_url)).reshape((1, -1))
    class_index = int(model.predict(single_row)[0])
    return labels.get(class_index)
# Define prediction endpoint
@app.post("/predict/")
def predict(request: URLRequest):
    """Classify ``request.url`` and return ``{"prediction": <label>}``.

    Text that ``urllib.parse`` cannot split (e.g. an unterminated IPv6
    bracket such as "http://[::1") raises ``ValueError`` from the
    urlparse-based feature extractors. That previously surfaced as an HTTP
    500. It is now checked up front and reported as a 422 client error.
    """
    try:
        urlparse(request.url)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=f"Invalid URL: {exc}") from exc
    prediction = get_prediction_from_url(request.url)
    return {"prediction": prediction}
if __name__ == "__main__":
    # When executed as a script, serve the API with uvicorn on all
    # interfaces (0.0.0.0) at port 8000. uvicorn is imported here, so it is
    # only needed when the file is run this way.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)