XSS_Payload_Detector / inference_bert_url.py
kd7979148's picture
Update inference_bert_url.py
4f77521 verified
# -*- coding: utf-8 -*-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from urllib.parse import (
urlparse,
parse_qs,
unquote
)
#################################################
# model path
#################################################
model_path = "xss_detect_trained"
#################################################
# URL existence
#################################################
def is_url(text):
return text.startswith("http://") or text.startswith("https://")
#################################################
# URL에서 parameter value
#################################################
def extract_url_payload(url):
try:
parsed = urlparse(url)
# query parameter
params = parse_qs(parsed.query)
extracted = []
for key, values in params.items():
for value in values:
# URL decode
decoded = unquote(value)
extracted.append(decoded)
# use path when no parameter
if not extracted:
return parsed.path
# combine multiple parameters
return " ".join(extracted)
except:
return url
#################################################
# check
#################################################
def contains_suspicious_code(text):
suspicious_patterns = [
# HTML / JS
"<",
">",
"script",
"javascript:",
"onerror",
"onclick",
"onload",
"iframe",
"svg",
# JS
"eval(",
"alert(",
"prompt(",
"confirm(",
"document.cookie",
"document.domain",
"window.location",
# bypass
"constructor",
"fromcharcode",
"\\x",
"%3c",
"%3e",
"&#",
"base64",
"atob(",
#
"srcdoc",
"data:text/html",
"vbscript:",
"expression("
]
text_lower = text.lower()
for pattern in suspicious_patterns:
if pattern in text_lower:
return True
return False
#################################################
# load
#################################################
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
device = torch.device("cpu")
model.to(device)
model.eval()
#################################################
# label
#################################################
labels = {
0: "NORMAL",
1: "XSS"
}
#################################################
# test
#################################################
print("\n Test Start (type exit to end)\n")
while True:
text = input("input: ")
if text.lower() == "exit":
break
#################################################
# basic
#################################################
target_text = text
#################################################
# URL
#################################################
if is_url(text):
target_text = extract_url_payload(text)
print(f"[extracted parameter]: {target_text}")
#################################################
# NORMAL when no suspicious code
#################################################
if not contains_suspicious_code(target_text):
print("result: NORMAL")
print("Reliability: heuristic\n")
continue
#################################################
# tokenize
#################################################
MAX_INPUT_LENGTH = 2000
if len(target_text) > MAX_INPUT_LENGTH:
print("Input Length Exceeded\n")
continue
inputs = tokenizer(
target_text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=128
).to(device)
#################################################
#
#################################################
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probs = torch.softmax(logits, dim=1)
confidence, pred = torch.max(probs, dim=1)
pred = pred.item()
confidence = confidence.item()
label = labels[pred]
#################################################
# result
#################################################
print(f"result: {label}")
print(f"Reliability: {confidence:.4f}\n")