ZealPyae committed on
Commit
5d7f8a1
·
verified ·
1 Parent(s): ad7d3f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -31
app.py CHANGED
@@ -1,32 +1,110 @@
1
  from flask import Flask, request, jsonify
 
 
 
2
  import numpy as np
3
- import joblib # to save and load the trained model
4
- import re
5
  from urllib.parse import urlparse
6
  from tld import get_tld
7
- from sklearn.ensemble import RandomForestClassifier
8
 
9
- app = Flask(__name__)
 
10
 
11
  # Load your trained model
12
- model = joblib.load('random_forest_model.pkl') # save your trained model using joblib
13
 
14
- # Define your feature extraction functions here...
 
 
15
 
 
16
  def having_ip_address(url):
17
- # Your implementation
18
- pass
 
 
 
 
19
 
20
  def abnormal_url(url):
21
- # Your implementation
22
- pass
 
23
 
24
  def count_dot(url):
25
- # Your implementation
26
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Define other functions similarly...
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def main(url):
31
  status = []
32
  status.append(having_ip_address(url))
@@ -53,23 +131,24 @@ def main(url):
53
  status.append(tld_length(tld))
54
  return status
55
 
56
- @app.route('/predict', methods=['POST'])
57
- def predict():
58
- data = request.json
59
- url = data.get('url')
60
- features = np.array(main(url)).reshape(1, -1)
61
- prediction = model.predict(features)
62
-
63
- if int(prediction[0]) == 0:
64
- result = "SAFE"
65
- elif int(prediction[0]) == 1:
66
- result = "DEFACEMENT"
67
- elif int(prediction[0]) == 2:
68
- result = "PHISHING"
69
- elif int(prediction[0]) == 3:
70
- result = "MALWARE"
71
-
72
- return jsonify({"prediction": result})
73
 
 
 
 
 
 
 
74
  if __name__ == '__main__':
75
- app.run(host='0.0.0.0', port=8000)
 
1
  from flask import Flask, request, jsonify
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
+ import joblib
5
  import numpy as np
 
 
6
  from urllib.parse import urlparse
7
  from tld import get_tld
8
+ import re
9
 
10
+ # FastAPI instance
11
+ app = FastAPI()
12
 
13
  # Load your trained model
14
+ model = joblib.load("rf_model.pkl") # Ensure you save your RandomForest model as rf_model.pkl
15
 
16
+ # Define the request body
17
+ class URLRequest(BaseModel):
18
+ url: str
19
 
20
+ # Feature extraction functions
21
  def having_ip_address(url):
22
+ match = re.search(
23
+ '(([01]?\\d\\d?|2[0-4]\\d|25[0-5])\\.([01]?\\d\\d?|2[0-4]\\d|25[0-5])\\.([01]?\\d\\d?|2[0-4]\\d|25[0-5])\\.'
24
+ '([01]?\\d\\d?|2[0-4]\\d|25[0-5])\\/)|' # IPv4
25
+ '((0x[0-9a-fA-F]{1,2})\\.(0x[0-9a-fA-F]{1,2})\\.(0x[0-9a-fA-F]{1,2})\\.(0x[0-9a-fA-F]{1,2})\\/)' # IPv4 in hexadecimal
26
+ '(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}', url) # Ipv6
27
+ return 1 if match else 0
28
 
29
  def abnormal_url(url):
30
+ hostname = urlparse(url).hostname
31
+ match = re.search(str(hostname), url)
32
+ return 1 if match else 0
33
 
34
  def count_dot(url):
35
+ return url.count('.')
36
+
37
+ def count_www(url):
38
+ return url.count('www')
39
+
40
+ def count_atrate(url):
41
+ return url.count('@')
42
+
43
+ def no_of_dir(url):
44
+ return urlparse(url).path.count('/')
45
+
46
+ def no_of_embed(url):
47
+ return urlparse(url).path.count('//')
48
+
49
+ def shortening_service(url):
50
+ match = re.search('bit\.ly|goo\.gl|shorte\.st|go2l\.ink|x\.co|ow\.ly|t\.co|tinyurl|tr\.im|is\.gd|cli\.gs|'
51
+ 'yfrog\.com|migre\.me|ff\.im|tiny\.cc|url4\.eu|twit\.ac|su\.pr|twurl\.nl|snipurl\.com|'
52
+ 'short\.to|BudURL\.com|ping\.fm|post\.ly|Just\.as|bkite\.com|snipr\.com|fic\.kr|loopt\.us|'
53
+ 'doiop\.com|short\.ie|kl\.am|wp\.me|rubyurl\.com|om\.ly|to\.ly|bit\.do|t\.co|lnkd\.in|'
54
+ 'db\.tt|qr\.ae|adf\.ly|goo\.gl|bitly\.com|cur\.lv|tinyurl\.com|ow\.ly|bit\.ly|ity\.im|'
55
+ 'q\.gs|is\.gd|po\.st|bc\.vc|twitthis\.com|u\.to|j\.mp|buzurl\.com|cutt\.us|u\.bb|yourls\.org|'
56
+ 'x\.co|prettylinkpro\.com|scrnch\.me|filoops\.info|vzturl\.com|qr\.net|1url\.com|tweez\.me|v\.gd|'
57
+ 'tr\.im|link\.zip\.net', url)
58
+ return 1 if match else 0
59
+
60
+ def count_https(url):
61
+ return url.count('https')
62
+
63
+ def count_http(url):
64
+ return url.count('http')
65
+
66
+ def count_per(url):
67
+ return url.count('%')
68
+
69
+ def count_ques(url):
70
+ return url.count('?')
71
 
72
+ def count_hyphen(url):
73
+ return url.count('-')
74
 
75
+ def count_equal(url):
76
+ return url.count('=')
77
+
78
+ def url_length(url):
79
+ return len(str(url))
80
+
81
+ def hostname_length(url):
82
+ return len(urlparse(url).netloc)
83
+
84
+ def suspicious_words(url):
85
+ match = re.search('PayPal|login|signin|bank|account|update|free|lucky|service|bonus|ebayisapi|webscr', url)
86
+ return 1 if match else 0
87
+
88
+ def digit_count(url):
89
+ return sum(1 for i in url if i.isnumeric())
90
+
91
+ def letter_count(url):
92
+ return sum(1 for i in url if i.isalpha())
93
+
94
+ def fd_length(url):
95
+ urlpath = urlparse(url).path
96
+ try:
97
+ return len(urlpath.split('/')[1])
98
+ except:
99
+ return 0
100
+
101
+ def tld_length(tld):
102
+ try:
103
+ return len(tld)
104
+ except:
105
+ return -1
106
+
107
+ # Extract features from URL
108
  def main(url):
109
  status = []
110
  status.append(having_ip_address(url))
 
131
  status.append(tld_length(tld))
132
  return status
133
 
134
+ def get_prediction_from_url(test_url):
135
+ features_test = main(test_url)
136
+ features_test = np.array(features_test).reshape((1, -1))
137
+ pred = model.predict(features_test)
138
+ if int(pred[0]) == 0:
139
+ return "SAFE"
140
+ elif int(pred[0]) == 1:
141
+ return "DEFACEMENT"
142
+ elif int(pred[0]) == 2:
143
+ return "PHISHING"
144
+ elif int(pred[0]) == 3:
145
+ return "MALWARE"
 
 
 
 
 
146
 
147
+ # Define prediction endpoint
148
+ @app.post("/predict/")
149
+ def predict(request: URLRequest):
150
+ prediction = get_prediction_from_url(request.url)
151
+ return {"prediction": prediction}
152
+
153
  if __name__ == '__main__':
154
+ app.run(host='0.0.0.0', port=8000)