""" HTML Phishing Detection - Interactive Prediction Predicts if HTML file/URL is phishing using trained model """ import sys from pathlib import Path import joblib import pandas as pd from colorama import init, Fore, Style import requests # Add project root to path sys.path.append(str(Path(__file__).parent.parent)) from scripts.feature_extraction.html.html_feature_extractor import HTMLFeatureExtractor from scripts.feature_extraction.html.feature_engineering import engineer_features # Initialize colorama init(autoreset=True) class HTMLPhishingPredictor: """Predict phishing from HTML content using trained models.""" def __init__(self): """Initialize predictor with all trained models.""" models_dir = Path('saved_models') # Load Random Forest model and its feature names rf_model_path = models_dir / 'random_forest_html.joblib' rf_features_path = models_dir / 'random_forest_html_feature_names.joblib' if rf_model_path.exists(): print(f"Loading Random Forest model: {rf_model_path}") self.rf_model = joblib.load(rf_model_path) self.has_rf = True # Load RF feature names if rf_features_path.exists(): self.rf_feature_names = joblib.load(rf_features_path) print(f"Loaded {len(self.rf_feature_names)} Random Forest feature names") else: self.rf_feature_names = None else: print(f"{Fore.YELLOW}Random Forest model not found{Style.RESET_ALL}") self.rf_model = None self.has_rf = False self.rf_feature_names = None # Load XGBoost model and its feature names xgb_model_path = models_dir / 'xgboost_html.joblib' xgb_features_path = models_dir / 'xgboost_html_feature_names.joblib' if xgb_model_path.exists(): print(f"Loading XGBoost model: {xgb_model_path}") self.xgb_model = joblib.load(xgb_model_path) self.has_xgb = True # Load XGBoost feature names if xgb_features_path.exists(): self.xgb_feature_names = joblib.load(xgb_features_path) print(f"Loaded {len(self.xgb_feature_names)} XGBoost feature names") else: self.xgb_feature_names = None else: print(f"{Fore.YELLOW}XGBoost model not found{Style.RESET_ALL}") self.xgb_model = None self.has_xgb = False self.xgb_feature_names = None if not self.has_rf and not self.has_xgb: raise FileNotFoundError("No trained models found! Train models first.") self.extractor = HTMLFeatureExtractor() def predict_from_file(self, html_file_path): """Predict from HTML file.""" # Read HTML content with open(html_file_path, 'r', encoding='utf-8', errors='ignore') as f: html_content = f.read() return self.predict_from_html(html_content, source=str(html_file_path)) def predict_from_url(self, url): """Download HTML from URL and predict.""" print(f"\nDownloading HTML from: {url}") try: # Download HTML response = requests.get(url, timeout=10, verify=False) html_content = response.text return self.predict_from_html(html_content, source=url) except Exception as e: print(f"{Fore.RED}Error downloading URL: {e}") return None def predict_from_html(self, html_content, source=""): """Predict from HTML content using all available models.""" # Extract raw features features = self.extractor.extract_features(html_content) # Apply feature engineering (same as training) raw_df = pd.DataFrame([features]) eng_df = engineer_features(raw_df) # Get predictions from all models predictions = {} if self.has_rf: if self.rf_feature_names: feature_values = [eng_df[fn].iloc[0] if fn in eng_df.columns else features.get(fn, 0) for fn in self.rf_feature_names] X_rf = pd.DataFrame([dict(zip(self.rf_feature_names, feature_values))]) else: X_rf = eng_df rf_pred = self.rf_model.predict(X_rf)[0] # type: ignore rf_proba = self.rf_model.predict_proba(X_rf)[0] # type: ignore predictions['Random Forest'] = { 'prediction': rf_pred, 'probability': rf_proba } if self.has_xgb: if self.xgb_feature_names: feature_values = [eng_df[fn].iloc[0] if fn in eng_df.columns else features.get(fn, 0) for fn in self.xgb_feature_names] X_xgb = pd.DataFrame([dict(zip(self.xgb_feature_names, feature_values))]) else: X_xgb = eng_df xgb_pred = self.xgb_model.predict(X_xgb)[0] # type: ignore xgb_proba = self.xgb_model.predict_proba(X_xgb)[0] # type: ignore predictions['XGBoost'] = { 'prediction': xgb_pred, 'probability': xgb_proba } # Ensemble prediction (average probabilities) if len(predictions) > 1: avg_proba = sum([p['probability'] for p in predictions.values()]) / len(predictions) ensemble_pred = 1 if avg_proba[1] > 0.5 else 0 # type: ignore predictions['Ensemble'] = { 'prediction': ensemble_pred, 'probability': avg_proba } # Display results self._display_prediction(predictions, features, source) return { 'predictions': predictions, 'features': features } def _display_prediction(self, predictions, features, source): """Display prediction results with colors.""" print("\n" + "="*80) if source: print(f"Source: {source}") print("="*80) # Get ensemble or single prediction for final verdict if 'Ensemble' in predictions: final_pred = predictions['Ensemble']['prediction'] final_proba = predictions['Ensemble']['probability'] else: # Use the only available model model_name = list(predictions.keys())[0] final_pred = predictions[model_name]['prediction'] final_proba = predictions[model_name]['probability'] # Final Verdict if final_pred == 1: print(f"\n{Fore.RED}{'⚠ PHISHING DETECTED ⚠':^80}") print(f"{Fore.RED}Confidence: {final_proba[1]*100:.2f}%") else: print(f"\n{Fore.GREEN}{'✓ LEGITIMATE WEBSITE ✓':^80}") print(f"{Fore.GREEN}Confidence: {final_proba[0]*100:.2f}%") # Model breakdown print("\n" + "-"*80) print("Model Predictions:") print("-"*80) for model_name, result in predictions.items(): pred = result['prediction'] proba = result['probability'] pred_text = 'PHISHING' if pred == 1 else 'LEGITIMATE' color = Fore.RED if pred == 1 else Fore.GREEN icon = "⚠" if pred == 1 else "✓" print(f" {icon} {model_name:15s}: {color}{pred_text:12s}{Style.RESET_ALL} " f"(Legit: {proba[0]*100:5.1f}%, Phish: {proba[1]*100:5.1f}%)") # Show key features print("\n" + "-"*80) print("Key HTML Features:") print("-"*80) important_features = [ ('num_forms', 'Number of forms'), ('num_password_fields', 'Password fields'), ('num_external_links', 'External links'), ('num_scripts', 'Scripts'), ('num_urgency_keywords', 'Urgency keywords'), ('num_brand_mentions', 'Brand mentions'), ('has_meta_refresh', 'Meta refresh redirect'), ('num_iframes', 'Iframes'), ] for feat, desc in important_features: if feat in features: value = features[feat] print(f" {desc:25s}: {value}") print("="*80) def interactive_mode(): """Interactive mode for testing multiple inputs.""" print("\n" + "="*80) print(f"{Fore.CYAN}{'HTML PHISHING DETECTOR - INTERACTIVE MODE':^80}") print("="*80) # Load predictor try: predictor = HTMLPhishingPredictor() except Exception as e: print(f"{Fore.RED}Error loading model: {e}") print("\nTrain a model first using:") print(" python models/html_enhanced/random_forest_html.py") return print("\nCommands:") print(" file - Analyze HTML file") print(" url - Download and analyze URL") print(" quit - Exit") print("-"*80) while True: try: user_input = input(f"\n{Fore.CYAN}Enter command: {Style.RESET_ALL}").strip() if not user_input: continue if user_input.lower() in ['quit', 'exit', 'q']: print("\nGoodbye!") break parts = user_input.split(maxsplit=1) command = parts[0].lower() if command == 'file' and len(parts) == 2: file_path = parts[1].strip() if Path(file_path).exists(): predictor.predict_from_file(file_path) else: print(f"{Fore.RED}File not found: {file_path}") elif command == 'url' and len(parts) == 2: url = parts[1].strip() predictor.predict_from_url(url) else: print(f"{Fore.YELLOW}Invalid command. Use: file or url ") except KeyboardInterrupt: print("\n\nGoodbye!") break except Exception as e: print(f"{Fore.RED}Error: {e}") def main(): """Main function.""" if len(sys.argv) > 1: # Command line mode predictor = HTMLPhishingPredictor() arg = sys.argv[1] if Path(arg).exists(): # File path predictor.predict_from_file(arg) elif arg.startswith('http'): # URL predictor.predict_from_url(arg) else: print(f"Invalid input: {arg}") print("\nUsage:") print(" python scripts/predict_html.py ") print(" python scripts/predict_html.py ") print(" python scripts/predict_html.py (interactive mode)") else: # Interactive mode interactive_mode() if __name__ == '__main__': main()