# Phishing-Detection-System/scripts/predict_url_cnn.py
# Uploaded by rb1337 ("Upload 50 files"), commit 2cc7f91 (verified).
"""
CNN Phishing Detector - Interactive Demo
Test any URL with both character-level CNN models:
1. CNN URL — analyzes the URL string itself
2. CNN HTML — fetches the page and analyzes its HTML source
Usage:
python scripts/predict_url_cnn.py
"""
import sys
import json
import logging
import warnings
from pathlib import Path
import os
# Environment flags read by TensorFlow when it is imported. They must be set
# before the first `import tensorflow`, which in this script happens lazily
# inside the model-loading methods.
# NOTE(review): intent appears to be quieter TF native logging and disabling
# oneDNN optimisations — confirm against the TensorFlow docs if changed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import numpy as np
from colorama import init, Fore, Style
# Initialise colorama; autoreset=True asks it to reset colours after each
# print, so a colour set on one line does not leak into the next.
init(autoreset=True)
# Silence ALL Python warnings process-wide. Among other things this hides the
# warning presumably emitted for the unverified HTTPS requests (verify=False)
# made when fetching page HTML.
warnings.filterwarnings('ignore')
# Console logging: timestamp (HH:MM:SS), level and message, INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
)
# Module logger used for model-loading and HTML-fetch diagnostics.
logger = logging.getLogger('cnn_predictor')
# ---------------------------------------------------------------------------
# Project paths
# ---------------------------------------------------------------------------
# This script lives in <repo>/scripts/, so parents[1] of the resolved file
# path is the repository root (not src/). Trained models and their
# vocabularies are expected in <repo>/saved_models/.
PROJECT_ROOT = Path(__file__).resolve().parents[1]  # repository root
MODELS_DIR = PROJECT_ROOT.joinpath('saved_models')

# Character-level CNN over the raw URL string: Keras model + vocabulary JSON.
URL_MODEL_PATH = MODELS_DIR.joinpath('cnn_url_model.keras')
URL_VOCAB_PATH = MODELS_DIR.joinpath('cnn_url_vocab.json')

# Character-level CNN over the fetched page HTML: Keras model + vocabulary JSON.
HTML_MODEL_PATH = MODELS_DIR.joinpath('cnn_html_model.keras')
HTML_VOCAB_PATH = MODELS_DIR.joinpath('cnn_html_vocab.json')
class CNNPhishingDetector:
"""Detect phishing URLs using both character-level CNN models."""
def __init__(self):
self.url_model = None
self.html_model = None
self.url_vocab = None
self.html_vocab = None
self._load_url_model()
self._load_html_model()
# ── Loading ────────────────────────────────────────────────────
def _load_url_model(self):
"""Load URL CNN model and vocabulary."""
if not URL_VOCAB_PATH.exists() or not URL_MODEL_PATH.exists():
logger.warning("URL CNN model not found — skipping")
return
with open(URL_VOCAB_PATH, 'r') as f:
self.url_vocab = json.load(f)
import tensorflow as tf
self.url_model = tf.keras.models.load_model(str(URL_MODEL_PATH))
logger.info(f"✓ URL CNN loaded (vocab={self.url_vocab['vocab_size']}, "
f"max_len={self.url_vocab['max_len']})")
def _load_html_model(self):
"""Load HTML CNN model and vocabulary."""
if not HTML_VOCAB_PATH.exists() or not HTML_MODEL_PATH.exists():
logger.warning("HTML CNN model not found — skipping")
return
with open(HTML_VOCAB_PATH, 'r') as f:
self.html_vocab = json.load(f)
import tensorflow as tf
self.html_model = tf.keras.models.load_model(str(HTML_MODEL_PATH))
logger.info(f"✓ HTML CNN loaded (vocab={self.html_vocab['vocab_size']}, "
f"max_len={self.html_vocab['max_len']})")
# ── Encoding ───────────────────────────────────────────────────
def _encode_url(self, url: str) -> np.ndarray:
"""Encode a URL string for the URL CNN."""
char_to_idx = self.url_vocab['char_to_idx']
max_len = self.url_vocab['max_len']
encoded = [char_to_idx.get(c, 1) for c in url[:max_len]]
encoded += [0] * (max_len - len(encoded))
return np.array([encoded], dtype=np.int32)
def _encode_html(self, html: str) -> np.ndarray:
"""Encode an HTML string for the HTML CNN."""
char_to_idx = self.html_vocab['char_to_idx']
max_len = self.html_vocab['max_len']
encoded = [char_to_idx.get(c, 1) for c in html[:max_len]]
encoded += [0] * (max_len - len(encoded))
return np.array([encoded], dtype=np.int32)
# ── HTML fetching ──────────────────────────────────────────────
@staticmethod
def _fetch_html(url: str, timeout: int = 10) -> str | None:
"""Fetch HTML content from a URL. Returns None on failure."""
try:
import requests
headers = {
'User-Agent': ('Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
'AppleWebKit/537.36 (KHTML, like Gecko) '
'Chrome/120.0.0.0 Safari/537.36'),
}
resp = requests.get(url, headers=headers, timeout=timeout,
verify=False, allow_redirects=True)
resp.raise_for_status()
return resp.text
except Exception as e:
logger.warning(f" Could not fetch HTML: {e}")
return None
# ── Prediction ─────────────────────────────────────────────────
def predict_url(self, url: str, threshold: float = 0.5) -> dict | None:
"""Predict using the URL CNN model."""
if self.url_model is None:
return None
X = self._encode_url(url)
phishing_prob = float(self.url_model.predict(X, verbose=0)[0][0])
legitimate_prob = 1.0 - phishing_prob
is_phishing = phishing_prob >= threshold
return {
'model_name': 'CNN URL (Char-level)',
'prediction': 'PHISHING' if is_phishing else 'LEGITIMATE',
'prediction_code': int(is_phishing),
'confidence': (phishing_prob if is_phishing else legitimate_prob) * 100,
'phishing_probability': phishing_prob * 100,
'legitimate_probability': legitimate_prob * 100,
'threshold': threshold,
}
def predict_html(self, html: str, threshold: float = 0.5) -> dict | None:
"""Predict using the HTML CNN model."""
if self.html_model is None:
return None
X = self._encode_html(html)
phishing_prob = float(self.html_model.predict(X, verbose=0)[0][0])
legitimate_prob = 1.0 - phishing_prob
is_phishing = phishing_prob >= threshold
return {
'model_name': 'CNN HTML (Char-level)',
'prediction': 'PHISHING' if is_phishing else 'LEGITIMATE',
'prediction_code': int(is_phishing),
'confidence': (phishing_prob if is_phishing else legitimate_prob) * 100,
'phishing_probability': phishing_prob * 100,
'legitimate_probability': legitimate_prob * 100,
'threshold': threshold,
'html_length': len(html),
}
def predict_full(self, url: str, threshold: float = 0.5) -> dict:
"""
Run both CNN models on a URL.
Returns dict with url_result, html_result, and combined verdict.
"""
# URL CNN
url_result = self.predict_url(url, threshold)
# HTML CNN — fetch page first
html_result = None
html_content = None
if self.html_model is not None:
html_content = self._fetch_html(url)
if html_content and len(html_content) >= 100:
html_result = self.predict_html(html_content, threshold)
# Combined verdict
results = [r for r in [url_result, html_result] if r is not None]
if len(results) == 2:
avg_phish = (url_result['phishing_probability'] +
html_result['phishing_probability']) / 2
combined_is_phishing = avg_phish >= (threshold * 100)
combined = {
'prediction': 'PHISHING' if combined_is_phishing else 'LEGITIMATE',
'phishing_probability': avg_phish,
'legitimate_probability': 100 - avg_phish,
'confidence': avg_phish if combined_is_phishing else 100 - avg_phish,
'agree': url_result['prediction'] == html_result['prediction'],
}
elif len(results) == 1:
r = results[0]
combined = {
'prediction': r['prediction'],
'phishing_probability': r['phishing_probability'],
'legitimate_probability': r['legitimate_probability'],
'confidence': r['confidence'],
'agree': True,
}
else:
combined = None
return {
'url_result': url_result,
'html_result': html_result,
'html_fetched': html_content is not None,
'html_length': len(html_content) if html_content else 0,
'combined': combined,
}
# ── Pretty print ───────────────────────────────────────────────
def print_results(self, url: str, full: dict):
"""Print formatted prediction results from both models."""
print("\n" + "=" * 80)
print(f"{Fore.CYAN}{Style.BRIGHT}CNN PHISHING DETECTION RESULTS{Style.RESET_ALL}")
print("=" * 80)
print(f"\n{Fore.YELLOW}URL:{Style.RESET_ALL} {url}")
# ── URL CNN ──
url_r = full['url_result']
if url_r:
pred = url_r['prediction']
color = Fore.RED if pred == 'PHISHING' else Fore.GREEN
icon = "⚠️" if pred == 'PHISHING' else "✓"
print(f"\n{Style.BRIGHT}1. CNN URL (Character-level):{Style.RESET_ALL}")
print(f" {icon} Prediction: {color}{Style.BRIGHT}{pred}{Style.RESET_ALL}")
print(f" Confidence: {url_r['confidence']:.2f}%")
print(f" Phishing: {Fore.RED}{url_r['phishing_probability']:6.2f}%{Style.RESET_ALL}")
print(f" Legitimate: {Fore.GREEN}{url_r['legitimate_probability']:6.2f}%{Style.RESET_ALL}")
else:
print(f"\n{Style.BRIGHT}1. CNN URL:{Style.RESET_ALL} {Fore.YELLOW}Not available{Style.RESET_ALL}")
# ── HTML CNN ──
html_r = full['html_result']
if html_r:
pred = html_r['prediction']
color = Fore.RED if pred == 'PHISHING' else Fore.GREEN
icon = "⚠️" if pred == 'PHISHING' else "✓"
print(f"\n{Style.BRIGHT}2. CNN HTML (Character-level):{Style.RESET_ALL}")
print(f" {icon} Prediction: {color}{Style.BRIGHT}{pred}{Style.RESET_ALL}")
print(f" Confidence: {html_r['confidence']:.2f}%")
print(f" Phishing: {Fore.RED}{html_r['phishing_probability']:6.2f}%{Style.RESET_ALL}")
print(f" Legitimate: {Fore.GREEN}{html_r['legitimate_probability']:6.2f}%{Style.RESET_ALL}")
print(f" HTML length: {html_r['html_length']:,} chars")
elif full['html_fetched']:
print(f"\n{Style.BRIGHT}2. CNN HTML:{Style.RESET_ALL} "
f"{Fore.YELLOW}HTML too short for analysis{Style.RESET_ALL}")
else:
print(f"\n{Style.BRIGHT}2. CNN HTML:{Style.RESET_ALL} "
f"{Fore.YELLOW}Could not fetch page HTML{Style.RESET_ALL}")
# ── Combined verdict ──
combined = full['combined']
if combined:
pred = combined['prediction']
color = Fore.RED if pred == 'PHISHING' else Fore.GREEN
icon = "⚠️" if pred == 'PHISHING' else "✓"
agree_str = (f"{Fore.GREEN}YES{Style.RESET_ALL}" if combined['agree']
else f"{Fore.YELLOW}NO{Style.RESET_ALL}")
print(f"\n{'─' * 80}")
print(f"{Style.BRIGHT}COMBINED VERDICT:{Style.RESET_ALL}")
print(f" {icon} {color}{Style.BRIGHT}{pred}{Style.RESET_ALL} "
f"(confidence: {combined['confidence']:.2f}%)")
print(f" Phishing: {Fore.RED}{combined['phishing_probability']:6.2f}%{Style.RESET_ALL}")
print(f" Legitimate: {Fore.GREEN}{combined['legitimate_probability']:6.2f}%{Style.RESET_ALL}")
if url_r and html_r:
print(f" Models agree: {agree_str}")
print("\n" + "=" * 80 + "\n")
def main():
    """Run the interactive prediction loop.

    Loads the CNN models and exits with status 1 if neither is available.
    Then it repeatedly prompts for a URL, prefixes ``http://`` when no
    http(s) scheme is given, and prints both models' verdicts. The loop
    ends on 'quit', 'exit' or 'q' (any case), on Ctrl-C at the prompt, or
    at end of input.
    """
    print(f"\n{Fore.CYAN}{Style.BRIGHT}╔══════════════════════════════════════════════════════════════╗")
    print("║ CNN PHISHING DETECTOR - INTERACTIVE DEMO ║")
    print("║ URL CNN + HTML CNN (Dual Analysis) ║")
    print(f"╚══════════════════════════════════════════════════════════════╝{Style.RESET_ALL}\n")
    print(f"{Fore.YELLOW}Loading CNN models...{Style.RESET_ALL}")

    detector = CNNPhishingDetector()
    available = []
    if detector.url_model is not None:
        available.append("URL CNN")
    if detector.html_model is not None:
        available.append("HTML CNN")
    if not available:
        print(f"{Fore.RED}No CNN models found! Train models first.{Style.RESET_ALL}")
        sys.exit(1)
    print(f"{Fore.GREEN}✓ Models loaded: {', '.join(available)}{Style.RESET_ALL}\n")

    while True:
        print(f"{Fore.CYAN}{'─' * 80}{Style.RESET_ALL}")
        try:
            url = input(f"{Fore.YELLOW}Enter URL to test (or 'quit' to exit):{Style.RESET_ALL} ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of piped input (Ctrl-D) or Ctrl-C at the prompt: leave the
            # loop like an explicit 'quit' instead of dumping a traceback.
            url = 'quit'
        if url.lower() in ('quit', 'exit', 'q'):
            print(f"\n{Fore.GREEN}Goodbye!{Style.RESET_ALL}\n")
            break
        if not url:
            print(f"{Fore.RED}Please enter a valid URL.{Style.RESET_ALL}\n")
            continue
        # URL schemes are case-insensitive; without lower() "HTTPS://x" would
        # be turned into "http://HTTPS://x".
        if not url.lower().startswith(('http://', 'https://')):
            url = 'http://' + url
        try:
            full = detector.predict_full(url)
            detector.print_results(url, full)
        except Exception as e:
            # Top-level boundary of the REPL: report and keep accepting URLs.
            print(f"\n{Fore.RED}Error: {e}{Style.RESET_ALL}\n")
            logger.error(str(e))


if __name__ == '__main__':
    main()