"""
HTML Phishing Detection - Interactive Prediction
Predicts if HTML file/URL is phishing using trained model
"""
import sys
from pathlib import Path
import joblib
import pandas as pd
from colorama import init, Fore, Style
import requests
# Add project root to path
sys.path.append(str(Path(__file__).parent.parent))
from scripts.feature_extraction.html.html_feature_extractor import HTMLFeatureExtractor
from scripts.feature_extraction.html.feature_engineering import engineer_features
# Initialize colorama
init(autoreset=True)
class HTMLPhishingPredictor:
    """Predict phishing from HTML content using trained models.

    Loads whichever of the Random Forest / XGBoost HTML models exist in the
    models directory. When both are available their class probabilities are
    averaged into an additional 'Ensemble' prediction.
    """

    def __init__(self, models_dir='saved_models'):
        """Initialize predictor with all trained models found on disk.

        Args:
            models_dir: Directory containing the ``*.joblib`` model files.
                A relative path is resolved against the current working
                directory. Defaults to ``saved_models``.

        Raises:
            FileNotFoundError: If neither model file exists.
        """
        models_dir = Path(models_dir)

        self.rf_model, self.rf_feature_names = self._load_model(
            models_dir, 'random_forest_html', 'Random Forest')
        self.has_rf = self.rf_model is not None

        self.xgb_model, self.xgb_feature_names = self._load_model(
            models_dir, 'xgboost_html', 'XGBoost')
        self.has_xgb = self.xgb_model is not None

        if not self.has_rf and not self.has_xgb:
            raise FileNotFoundError("No trained models found! Train models first.")

        self.extractor = HTMLFeatureExtractor()

    @staticmethod
    def _load_model(models_dir, stem, label):
        """Load ``<stem>.joblib`` and its optional feature-name list.

        Args:
            models_dir: Directory (``Path``) to look in.
            stem: File name stem, e.g. ``'xgboost_html'``.
            label: Human-readable model name used in console messages.

        Returns:
            ``(model, feature_names)``. Both are ``None`` when the model file
            is missing; ``feature_names`` alone is ``None`` when only the
            ``<stem>_feature_names.joblib`` file is missing.
        """
        model_path = models_dir / f'{stem}.joblib'
        features_path = models_dir / f'{stem}_feature_names.joblib'

        if not model_path.exists():
            print(f"{Fore.YELLOW}{label} model not found{Style.RESET_ALL}")
            return None, None

        print(f"Loading {label} model: {model_path}")
        # NOTE: joblib.load unpickles the file, so models_dir must only
        # contain artifacts from a trusted source.
        model = joblib.load(model_path)

        feature_names = None
        if features_path.exists():
            feature_names = joblib.load(features_path)
            print(f"Loaded {len(feature_names)} {label} feature names")
        return model, feature_names

    @staticmethod
    def _build_model_input(eng_df, features, feature_names):
        """Return the one-row frame to feed a model.

        Args:
            eng_df: Engineered feature frame (one row).
            features: Raw feature dict the frame was built from.
            feature_names: Column names saved alongside the model, or
                ``None``/empty when no list was saved.

        Returns:
            ``eng_df`` itself when no feature names are known; otherwise a new
            one-row DataFrame whose columns are exactly ``feature_names`` in
            order. Each value comes from ``eng_df`` when the column exists
            there, else from ``features``, else ``0``.
        """
        # Explicit None/len check: the saved names may be a numpy array or
        # pandas Index, whose truth value is ambiguous and raises ValueError.
        if feature_names is None or len(feature_names) == 0:
            return eng_df

        row = {
            name: (eng_df[name].iloc[0] if name in eng_df.columns
                   else features.get(name, 0))
            for name in feature_names
        }
        return pd.DataFrame([row])

    def predict_from_file(self, html_file_path):
        """Predict from HTML file.

        Args:
            html_file_path: Path of the HTML file to analyze.

        Returns:
            The result dict of :meth:`predict_from_html`.
        """
        # errors='ignore' drops bytes that are not valid UTF-8 instead of
        # failing on pages saved in another encoding.
        with open(html_file_path, 'r', encoding='utf-8', errors='ignore') as f:
            html_content = f.read()
        return self.predict_from_html(html_content, source=str(html_file_path))

    def predict_from_url(self, url):
        """Download HTML from URL and predict.

        Args:
            url: Address of the page to fetch.

        Returns:
            The result dict of :meth:`predict_from_html`, or ``None`` when
            the download fails (the error is printed).

        Errors raised while analyzing the downloaded page are not caught
        here; they propagate exactly as they do from
        :meth:`predict_from_file`.
        """
        print(f"\nDownloading HTML from: {url}")
        try:
            # NOTE(review): verify=False disables TLS certificate checks,
            # presumably so pages served with broken certificates can still
            # be analyzed - confirm this is intended.
            response = requests.get(url, timeout=10, verify=False)
            html_content = response.text
        except requests.RequestException as e:
            print(f"{Fore.RED}Error downloading URL: {e}")
            return None

        # Kept outside the try block so that a feature-extraction or model
        # failure is not misreported as a download error.
        return self.predict_from_html(html_content, source=url)

    def predict_from_html(self, html_content, source=""):
        """Predict from HTML content using all available models.

        Args:
            html_content: HTML markup as a string.
            source: Label (file path or URL) shown in the printed report.

        Returns:
            ``{'predictions': ..., 'features': ...}`` where ``predictions``
            maps a model name ('Random Forest', 'XGBoost' and, when both ran,
            'Ensemble') to ``{'prediction', 'probability'}`` and ``features``
            is the raw feature dict from the extractor.
        """
        # Extract raw features, then apply the same engineering as training.
        features = self.extractor.extract_features(html_content)
        eng_df = engineer_features(pd.DataFrame([features]))

        candidates = (
            ('Random Forest', self.has_rf, self.rf_model, self.rf_feature_names),
            ('XGBoost', self.has_xgb, self.xgb_model, self.xgb_feature_names),
        )
        predictions = {}
        for label, available, model, feature_names in candidates:
            if not available:
                continue
            model_input = self._build_model_input(eng_df, features, feature_names)
            predictions[label] = {
                'prediction': model.predict(model_input)[0],
                'probability': model.predict_proba(model_input)[0],
            }

        # Ensemble prediction: average the per-class probabilities and call
        # it phishing when the averaged class-1 probability exceeds 0.5.
        if len(predictions) > 1:
            avg_proba = (sum(p['probability'] for p in predictions.values())
                         / len(predictions))
            predictions['Ensemble'] = {
                'prediction': 1 if avg_proba[1] > 0.5 else 0,
                'probability': avg_proba,
            }

        self._display_prediction(predictions, features, source)
        return {
            'predictions': predictions,
            'features': features
        }

    def _display_prediction(self, predictions, features, source):
        """Display prediction results with colors.

        Args:
            predictions: Mapping of model name to
                ``{'prediction', 'probability'}`` as built by
                :meth:`predict_from_html`; must not be empty.
            features: Raw feature dict; a fixed subset is printed.
            source: Label printed in the header; skipped when empty.
        """
        print("\n" + "="*80)
        if source:
            print(f"Source: {source}")
        print("="*80)

        # The verdict comes from the ensemble when present, otherwise from
        # the only model that produced a prediction.
        verdict_key = 'Ensemble' if 'Ensemble' in predictions else next(iter(predictions))
        final_pred = predictions[verdict_key]['prediction']
        final_proba = predictions[verdict_key]['probability']

        # Final verdict: class 1 is reported as phishing, anything else as
        # legitimate, with the matching class probability as confidence.
        if final_pred == 1:
            print(f"\n{Fore.RED}{'⚠ PHISHING DETECTED ⚠':^80}")
            print(f"{Fore.RED}Confidence: {final_proba[1]*100:.2f}%")
        else:
            print(f"\n{Fore.GREEN}{'✓ LEGITIMATE WEBSITE ✓':^80}")
            print(f"{Fore.GREEN}Confidence: {final_proba[0]*100:.2f}%")

        # Model breakdown
        print("\n" + "-"*80)
        print("Model Predictions:")
        print("-"*80)
        for model_name, result in predictions.items():
            pred = result['prediction']
            proba = result['probability']
            is_phishing = pred == 1
            pred_text = 'PHISHING' if is_phishing else 'LEGITIMATE'
            color = Fore.RED if is_phishing else Fore.GREEN
            icon = "⚠" if is_phishing else "✓"
            print(f" {icon} {model_name:15s}: {color}{pred_text:12s}{Style.RESET_ALL} "
                  f"(Legit: {proba[0]*100:5.1f}%, Phish: {proba[1]*100:5.1f}%)")

        # Show key features (only those the extractor actually produced)
        print("\n" + "-"*80)
        print("Key HTML Features:")
        print("-"*80)
        important_features = [
            ('num_forms', 'Number of forms'),
            ('num_password_fields', 'Password fields'),
            ('num_external_links', 'External links'),
            ('num_scripts', 'Scripts'),
            ('num_urgency_keywords', 'Urgency keywords'),
            ('num_brand_mentions', 'Brand mentions'),
            ('has_meta_refresh', 'Meta refresh redirect'),
            ('num_iframes', 'Iframes'),
        ]
        for feat, desc in important_features:
            if feat in features:
                print(f" {desc:25s}: {features[feat]}")
        print("="*80)
def interactive_mode():
    """Interactive mode for testing multiple inputs.

    Loads the predictor once, then reads commands in a loop:
    ``file <path>`` analyzes a local HTML file, ``url <url>`` downloads and
    analyzes a page, and ``quit``/``exit``/``q`` leaves. Ctrl-C and
    end-of-input (Ctrl-D or exhausted piped stdin) also leave the loop.
    Any other error raised while handling a command is printed and the loop
    continues.
    """
    print("\n" + "="*80)
    print(f"{Fore.CYAN}{'HTML PHISHING DETECTOR - INTERACTIVE MODE':^80}")
    print("="*80)

    # Load predictor; without any model there is nothing to interact with.
    try:
        predictor = HTMLPhishingPredictor()
    except Exception as e:
        print(f"{Fore.RED}Error loading model: {e}")
        print("\nTrain a model first using:")
        print(" python models/html_enhanced/random_forest_html.py")
        return

    print("\nCommands:")
    print(" file <path> - Analyze HTML file")
    print(" url <url>   - Download and analyze URL")
    print(" quit        - Exit")
    print("-"*80)

    while True:
        try:
            user_input = input(f"\n{Fore.CYAN}Enter command: {Style.RESET_ALL}").strip()
            if not user_input:
                continue
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\nGoodbye!")
                break

            # Split into the command word and the rest of the line so that
            # paths containing spaces stay intact.
            parts = user_input.split(maxsplit=1)
            command = parts[0].lower()

            if command == 'file' and len(parts) == 2:
                file_path = parts[1].strip()
                # is_file() rather than exists(): a directory cannot be read
                # as HTML.
                if Path(file_path).is_file():
                    predictor.predict_from_file(file_path)
                else:
                    print(f"{Fore.RED}File not found: {file_path}")
            elif command == 'url' and len(parts) == 2:
                predictor.predict_from_url(parts[1].strip())
            else:
                print(f"{Fore.YELLOW}Invalid command. Use: file <path> or url <url>")
        except (KeyboardInterrupt, EOFError):
            # EOFError must end the session too: once stdin is exhausted
            # every further input() call raises it again, which would
            # otherwise make this loop spin forever.
            print("\n\nGoodbye!")
            break
        except Exception as e:
            print(f"{Fore.RED}Error: {e}")
def main():
    """Main function.

    With no command-line argument, starts interactive mode. Otherwise the
    first argument is treated as an HTML file when it names an existing
    file, or as a URL when it starts with ``http://`` or ``https://``;
    anything else prints a usage message. The argument is validated before
    the models are loaded, so a bad argument reports usage without touching
    the model files.
    """
    if len(sys.argv) <= 1:
        # Interactive mode
        interactive_mode()
        return

    # Command line mode
    arg = sys.argv[1]
    is_file = Path(arg).is_file()
    is_url = arg.startswith(('http://', 'https://'))

    if not is_file and not is_url:
        print(f"Invalid input: {arg}")
        print("\nUsage:")
        print(" python scripts/predict_html.py <html_file>")
        print(" python scripts/predict_html.py <url>")
        print(" python scripts/predict_html.py (interactive mode)")
        return

    predictor = HTMLPhishingPredictor()
    if is_file:
        # An existing file wins over the URL interpretation.
        predictor.predict_from_file(arg)
    else:
        predictor.predict_from_url(arg)


if __name__ == '__main__':
    main()