Spaces:
Sleeping
Sleeping
| # app.py | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import pandas as pd | |
| import numpy as np | |
| import re | |
| import whois | |
| import dns.resolver | |
| import tldextract | |
| import requests | |
| from urllib.parse import urlparse, parse_qs | |
| from datetime import datetime | |
| import time | |
| import warnings | |
| import joblib | |
| import os | |
| from functools import lru_cache | |
| import threading | |
# Silence library warnings (pandas/sklearn deprecation noise) so API logs stay readable.
warnings.filterwarnings('ignore')
# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes
# Extended whitelist of known legitimate domains.
# Stored as a frozenset: membership tests in is_domain_whitelisted() are O(1)
# instead of O(n), and duplicates are impossible (the original list carried
# 'aws.amazon.com' twice). Entries are lowercase registrable domains or
# specific legitimate subdomains.
LEGITIMATE_DOMAINS = frozenset([
    # Major tech companies
    'google.com', 'youtube.com', 'facebook.com', 'instagram.com', 'twitter.com',
    'linkedin.com', 'github.com', 'microsoft.com', 'apple.com', 'amazon.com',
    'wikipedia.org', 'stackoverflow.com', 'reddit.com', 'netflix.com', 'paypal.com',
    # Google services
    'accounts.google.com', 'mail.google.com', 'drive.google.com', 'docs.google.com',
    # Microsoft services
    'office.com', 'outlook.com', 'live.com', 'hotmail.com', 'onedrive.com',
    # Apple services
    'icloud.com', 'appleid.apple.com', 'developer.apple.com', 'support.apple.com',
    # Amazon services
    'aws.amazon.com', 'primevideo.com', 'music.amazon.com', 'photos.amazon.com',
    # Developer platforms
    'npmjs.com', 'pypi.org', 'rubygems.org', 'crates.io', 'packagist.org',
    'maven.org', 'nuget.org', 'cpan.org', 'docker.com', 'kubernetes.io',
    'tensorflow.org', 'pytorch.org', 'huggingface.co', 'kaggle.com',
    'colab.research.google.com', 'jupyter.org', 'anaconda.com',
    # Social media
    'pinterest.com', 'snapchat.com', 'tiktok.com', 'whatsapp.com', 'messenger.com',
    'telegram.org', 'discord.com', 'twitch.tv', 'medium.com', 'quora.com',
    # E-commerce
    'ebay.com', 'etsy.com', 'shopify.com', 'aliexpress.com', 'alibaba.com',
    'target.com', 'walmart.com', 'bestbuy.com', 'homedepot.com',
    # News and media
    'cnn.com', 'bbc.com', 'nytimes.com', 'theguardian.com', 'reuters.com',
    'bloomberg.com', 'wsj.com', 'forbes.com', 'techcrunch.com', 'wired.com',
    # Education
    'coursera.org', 'udemy.com', 'edx.org', 'khanacademy.org', 'mit.edu',
    'stanford.edu', 'harvard.edu', 'berkeley.edu', 'cmu.edu',
    # Cloud providers
    'azure.microsoft.com', 'cloud.google.com', 'oracle.com',
    'ibm.com', 'salesforce.com', 'sap.com', 'adobe.com', 'service-now.com',
    # Common legitimate hosting/CDN domains
    'githubusercontent.com', 'github.io', 'wordpress.com', 'wordpress.org',
    'bitbucket.org', 'gitlab.com', 'sourceforge.net', 'fandom.com'
])
# Pre-load feature columns to avoid downloading every time.
# Populated by initialize() at startup; stays None if initialization fails.
FEATURE_COLUMNS = None
# Thread-safe caches shared by fast_dns_lookup / fast_whois_lookup;
# all access is guarded by cache_lock.
dns_cache = {}
whois_cache = {}
cache_lock = threading.Lock()
# DNS resolver with a tight 1-second budget and Google public nameservers,
# so a slow or unreachable DNS server cannot stall an API request.
resolver = dns.resolver.Resolver()
resolver.timeout = 1.0   # per-nameserver timeout (seconds)
resolver.lifetime = 1.0  # total time budget for one query (seconds)
resolver.nameservers = ['8.8.8.8', '8.8.4.4']
# Load the models and feature columns once at startup
def initialize():
    """Load the persisted ML artifacts and the training feature order.

    Returns:
        (rf_model, xgb_model, scaler) on success, or (None, None, None)
        if any artifact fails to load. As a side effect, populates the
        module-level FEATURE_COLUMNS from the reference dataset so later
        feature frames can be column-aligned with the training data.
    """
    global FEATURE_COLUMNS
    try:
        print("Loading models...")
        artifacts = [
            joblib.load(path)
            for path in (
                'phishing_detector_rf_grega.pkl',
                'phishing_detector_xgb_grega.pkl',
                'feature_scaler_grega.pkl',
            )
        ]
        print("Models loaded successfully!")

        # Fetch the original dataset once just to recover the column order.
        print("Loading feature columns...")
        url = "https://raw.githubusercontent.com/GregaVrbancic/Phishing-Dataset/master/dataset_small.csv"
        FEATURE_COLUMNS = pd.read_csv(url).drop('phishing', axis=1).columns.tolist()
        print(f"Feature columns loaded: {len(FEATURE_COLUMNS)} features")

        return tuple(artifacts)
    except Exception as e:
        print(f"Initialization error: {e}")
        return None, None, None
# Initialize models at startup
print("Initializing models...")
rf_model, xgb_model, scaler = initialize()
# Use an explicit `is None` check: estimator objects have no meaningful
# __bool__, so relying on `not rf_model` truthiness is fragile.
if rf_model is None:
    print("Failed to initialize models!")
    # We'll continue and return error messages in the API
# Fast DNS lookup with caching
def fast_dns_lookup(domain, record_type='A'):
    """Resolve *domain* for *record_type* and return an integer feature.

    The returned value depends on the record type:
      - 'A':   the TTL of the answer rrset. NOTE(review): this overwrites
               the record count, so callers wanting the number of resolved
               IPs actually receive the TTL — preserved as-is because the
               deployed model was scored against these values; confirm
               before changing.
      - 'TXT': 1 if an SPF record ('v=spf1') is present, else 0.
      - other: the number of answer records.

    Returns 0 on any resolver failure. Successful results are cached in
    dns_cache (failures are not, so transient errors get retried).
    """
    cache_key = f"{domain}:{record_type}"
    # Check cache first (lock shared with the WHOIS cache).
    with cache_lock:
        if cache_key in dns_cache:
            return dns_cache[cache_key]
    try:
        answers = resolver.resolve(domain, record_type)
        result = len(answers)
        if record_type == 'A':
            result = answers.rrset.ttl if answers else 0
        elif record_type == 'TXT':
            result = 1 if any('v=spf1' in str(r) for r in answers) else 0
        # Cache the successful result.
        with cache_lock:
            dns_cache[cache_key] = result
        return result
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; any resolver error maps to 0.
        return 0
# Fast WHOIS lookup with caching
def fast_whois_lookup(domain):
    """Return (domain_age_days, days_to_expiration) for *domain*.

    Successful results are cached in whois_cache; (0, 0) is returned on
    any WHOIS failure. NOTE(review): some python-whois versions do not
    accept a `timeout` kwarg — if that raises TypeError it is swallowed
    below and every lookup silently yields (0, 0); verify against the
    installed version.
    """
    # Check cache first (lock shared with the DNS cache).
    with cache_lock:
        if domain in whois_cache:
            return whois_cache[domain]
    try:
        w = whois.whois(domain, timeout=2)
        creation_date = w.creation_date
        expiration_date = w.expiration_date
        # WHOIS libraries sometimes return a list of dates; use the first.
        if isinstance(creation_date, list):
            creation_date = creation_date[0]
        if isinstance(expiration_date, list):
            expiration_date = expiration_date[0]
        # Snapshot "now" once so both deltas use the same reference time.
        now = datetime.now()
        domain_age = (now - creation_date).days if creation_date else 0
        days_to_expiration = (expiration_date - now).days if expiration_date else 0
        result = (domain_age, days_to_expiration)
        # Cache the successful result.
        with cache_lock:
            whois_cache[domain] = result
        return result
    except Exception:
        # Narrowed from a bare `except:`; any failure maps to (0, 0).
        return (0, 0)
# Check if domain is whitelisted (optimized)
def is_domain_whitelisted(domain):
    """Check if domain or any of its parent domains are in the whitelist.

    The hostname is lowercased before comparison — DNS names are
    case-insensitive and the whitelist entries are all lowercase, so a
    mixed-case netloc like 'Google.com' must still match. Any ':port'
    suffix is stripped first.
    """
    domain = domain.split(':')[0].lower()
    # Direct match (most common case)
    if domain in LEGITIMATE_DOMAINS:
        return True
    # Walk up the parents: a.b.google.com -> b.google.com -> google.com
    parts = domain.split('.')
    for i in range(1, len(parts)):
        if '.'.join(parts[i:]) in LEGITIMATE_DOMAINS:
            return True
    return False
| # Optimized feature extraction | |
| def extract_features_fast(url): | |
| """Extract features with optimized performance.""" | |
| features = {} | |
| try: | |
| parsed = urlparse(url) | |
| except: | |
| return {col: 0 for col in FEATURE_COLUMNS} | |
| ext = tldextract.extract(url) | |
| domain = parsed.netloc | |
| path = parsed.path | |
| query = parsed.query | |
| # Fast URL-based features (single pass) | |
| url_chars = { | |
| '.': 'qty_dot_url', '-': 'qty_hyphen_url', '_': 'qty_underline_url', | |
| '/': 'qty_slash_url', '?': 'qty_questionmark_url', '=': 'qty_equal_url', | |
| '@': 'qty_at_url', '&': 'qty_and_url', '!': 'qty_exclamation_url', | |
| ' ': 'qty_space_url', '~': 'qty_tilde_url', ',': 'qty_comma_url', | |
| '+': 'qty_plus_url', '*': 'qty_asterisk_url', '#': 'qty_hashtag_url', | |
| '$': 'qty_dollar_url', '%': 'qty_percent_url' | |
| } | |
| for char, feature in url_chars.items(): | |
| features[feature] = url.count(char) | |
| features['qty_tld_url'] = len(ext.suffix) if ext.suffix else 0 | |
| features['length_url'] = len(url) | |
| # Fast domain-based features | |
| domain_chars = { | |
| '.': 'qty_dot_domain', '-': 'qty_hyphen_domain', '_': 'qty_underline_domain', | |
| '/': 'qty_slash_domain', '?': 'qty_questionmark_domain', '=': 'qty_equal_domain', | |
| '@': 'qty_at_domain', '&': 'qty_and_domain', '!': 'qty_exclamation_domain', | |
| ' ': 'qty_space_domain', '~': 'qty_tilde_domain', ',': 'qty_comma_domain', | |
| '+': 'qty_plus_domain', '*': 'qty_asterisk_domain', '#': 'qty_hashtag_domain', | |
| '$': 'qty_dollar_domain', '%': 'qty_percent_domain' | |
| } | |
| for char, feature in domain_chars.items(): | |
| features[feature] = domain.count(char) | |
| vowels = 'aeiouAEIOU' | |
| features['qty_vowels_domain'] = sum(1 for c in domain if c in vowels) | |
| features['domain_length'] = len(domain) | |
| features['domain_in_ip'] = 1 if re.match(r'^\d+\.\d+\.\d+\.\d+$', domain) else 0 | |
| features['server_client_domain'] = 1 if any(keyword in domain.lower() for keyword in ['server', 'client']) else 0 | |
| # Directory features (optimized) | |
| dir_chars = { | |
| '.': 'qty_dot_directory', '-': 'qty_hyphen_directory', '_': 'qty_underline_directory', | |
| '/': 'qty_slash_directory', '?': 'qty_questionmark_directory', '=': 'qty_equal_directory', | |
| '@': 'qty_at_directory', '&': 'qty_and_directory', '!': 'qty_exclamation_directory', | |
| ' ': 'qty_space_directory', '~': 'qty_tilde_directory', ',': 'qty_comma_directory', | |
| '+': 'qty_plus_directory', '*': 'qty_asterisk_directory', '#': 'qty_hashtag_directory', | |
| '$': 'qty_dollar_directory', '%': 'qty_percent_directory' | |
| } | |
| for char, feature in dir_chars.items(): | |
| features[feature] = path.count(char) | |
| features['directory_length'] = len(path) | |
| # File features | |
| file_name = path.split('/')[-1] if '/' in path else '' | |
| file_chars = { | |
| '.': 'qty_dot_file', '-': 'qty_hyphen_file', '_': 'qty_underline_file', | |
| '/': 'qty_slash_file', '?': 'qty_questionmark_file', '=': 'qty_equal_file', | |
| '@': 'qty_at_file', '&': 'qty_and_file', '!': 'qty_exclamation_file', | |
| ' ': 'qty_space_file', '~': 'qty_tilde_file', ',': 'qty_comma_file', | |
| '+': 'qty_plus_file', '*': 'qty_asterisk_file', '#': 'qty_hashtag_file', | |
| '$': 'qty_dollar_file', '%': 'qty_percent_file' | |
| } | |
| for char, feature in file_chars.items(): | |
| features[feature] = file_name.count(char) | |
| features['file_length'] = len(file_name) | |
| # Parameters features | |
| param_chars = { | |
| '.': 'qty_dot_params', '-': 'qty_hyphen_params', '_': 'qty_underline_params', | |
| '/': 'qty_slash_params', '?': 'qty_questionmark_params', '=': 'qty_equal_params', | |
| '@': 'qty_at_params', '&': 'qty_and_params', '!': 'qty_exclamation_params', | |
| ' ': 'qty_space_params', '~': 'qty_tilde_params', ',': 'qty_comma_params', | |
| '+': 'qty_plus_params', '*': 'qty_asterisk_params', '#': 'qty_hashtag_params', | |
| '$': 'qty_dollar_params', '%': 'qty_percent_params' | |
| } | |
| for char, feature in param_chars.items(): | |
| features[feature] = query.count(char) | |
| features['params_length'] = len(query) | |
| features['tld_present_params'] = 1 if ext.suffix and ext.suffix in query else 0 | |
| params = parse_qs(query) | |
| features['qty_params'] = len(params) | |
| features['email_in_url'] = 1 if re.search(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', url) else 0 | |
| # DNS and network features (optimized with caching) | |
| if is_domain_whitelisted(domain): | |
| # Use preset values for whitelisted domains | |
| features['time_response'] = 0.1 | |
| features['domain_spf'] = 1 | |
| features['asn_ip'] = 15169 | |
| features['time_domain_activation'] = 3650 # 10 years | |
| features['time_domain_expiration'] = 365 | |
| features['qty_ip_resolved'] = 1 | |
| features['qty_nameservers'] = 4 | |
| features['qty_mx_servers'] = 1 | |
| features['ttl_hostname'] = 300 | |
| features['tls_ssl_certificate'] = 1 | |
| features['qty_redirects'] = 0 | |
| features['url_google_index'] = 1 | |
| features['domain_google_index'] = 1 | |
| features['url_shortened'] = 0 | |
| else: | |
| # Use cached DNS lookups | |
| start_time = time.time() | |
| features['time_response'] = 0.1 # Default value | |
| features['domain_spf'] = fast_dns_lookup(domain, 'TXT') | |
| features['asn_ip'] = 0 | |
| # Use cached WHOIS | |
| domain_age, days_to_expiration = fast_whois_lookup(domain) | |
| features['time_domain_activation'] = domain_age | |
| features['time_domain_expiration'] = days_to_expiration | |
| # DNS lookups | |
| features['qty_ip_resolved'] = fast_dns_lookup(domain, 'A') | |
| features['qty_nameservers'] = fast_dns_lookup(domain, 'NS') | |
| features['qty_mx_servers'] = fast_dns_lookup(domain, 'MX') | |
| features['ttl_hostname'] = fast_dns_lookup(domain, 'A') | |
| features['tls_ssl_certificate'] = 0 | |
| features['qty_redirects'] = 0 | |
| features['url_google_index'] = 0 | |
| features['domain_google_index'] = 0 | |
| shorteners = ['bit.ly', 'tinyurl.com', 'short.link', 'goo.gl', 't.co'] | |
| features['url_shortened'] = 1 if any(s in domain for s in shorteners) else 0 | |
| return features | |
# Fast prediction function
def predict_phishing_fast(url):
    """Classify *url* and return a JSON-serializable result dict.

    Whitelisted domains short-circuit to a near-zero probability without
    touching the models. Otherwise features are extracted, column-aligned
    with the training data, scaled, and scored by both models; the
    averaged phishing probability drives the verdict and confidence.
    """
    start_time = time.time()
    # Ensure URL has a protocol so urlparse populates netloc.
    if not url.startswith(('http://', 'https://')):
        url = 'http://' + url
    parsed = urlparse(url)
    domain = parsed.netloc
    # Whitelisted domains return immediately (no model scoring).
    if is_domain_whitelisted(domain):
        return {
            'url': url,
            'rf_phishing_probability': 0.01,
            'xgb_phishing_probability': 0.01,
            'avg_phishing_probability': 0.01,
            'is_phishing': False,
            'confidence': 'High',
            'whitelisted': True,
            'processing_time_ms': round((time.time() - start_time) * 1000, 2)
        }
    # Extract features and convert to a one-row DataFrame.
    features = extract_features_fast(url)
    features_df = pd.DataFrame([features])
    # Ensure all columns match training data; missing ones default to 0.
    missing_cols = set(FEATURE_COLUMNS) - set(features_df.columns)
    for col in missing_cols:
        features_df[col] = 0
    features_df = features_df[FEATURE_COLUMNS]
    # Scale features with the scaler fitted at training time.
    features_scaled = scaler.transform(features_df)
    # predict_proba yields [P(legit), P(phishing)]; keep the phishing column.
    # Cast to plain Python floats: numpy scalars (np.float64/np.bool_) are
    # not JSON-serializable by Flask's jsonify, which would 500 the request.
    rf_prob = float(rf_model.predict_proba(features_scaled)[0][1])
    xgb_prob = float(xgb_model.predict_proba(features_scaled)[0][1])
    # Ensemble by simple average of the two model probabilities.
    avg_prob = (rf_prob + xgb_prob) / 2
    # Confidence is high near either extreme, low in the middle band.
    if avg_prob > 0.8 or avg_prob < 0.2:
        confidence = 'High'
    elif avg_prob > 0.65 or avg_prob < 0.35:
        confidence = 'Medium'
    else:
        confidence = 'Low'
    return {
        'url': url,
        'rf_phishing_probability': round(rf_prob, 4),
        'xgb_phishing_probability': round(xgb_prob, 4),
        'avg_phishing_probability': round(avg_prob, 4),
        'is_phishing': bool(avg_prob > 0.8),
        'confidence': confidence,
        'whitelisted': False,
        'processing_time_ms': round((time.time() - start_time) * 1000, 2)
    }
# API endpoint for checking URLs.
# NOTE(review): this view had no @app.route registration even though the
# root documentation advertises '/api/check (POST)', so it was unreachable.
# Registered here; confirm the intended path against the deployed clients.
@app.route('/api/check', methods=['POST', 'OPTIONS'])
def check_url():
    """POST {"url": "..."} -> phishing verdict JSON (400/500 on errors)."""
    # Handle OPTIONS preflight request for CORS.
    if request.method == 'OPTIONS':
        return '', 200
    try:
        # Get JSON data from request body.
        data = request.get_json()
        if not data or 'url' not in data:
            return jsonify({
                'error': 'Missing URL parameter',
                'message': 'Please provide a URL in the request body'
            }), 400
        url = data['url']
        # Validate URL format: must be a non-empty string.
        if not isinstance(url, str) or not url.strip():
            return jsonify({
                'error': 'Invalid URL',
                'message': 'URL must be a non-empty string'
            }), 400
        # Make prediction and wrap it in the standard envelope.
        result = predict_phishing_fast(url)
        return jsonify({
            'success': True,
            'data': result
        })
    except Exception as e:
        # Catch-all boundary: surface the error message to the client.
        return jsonify({
            'success': False,
            'error': 'Internal server error',
            'message': str(e)
        }), 500
# Health check endpoint.
# NOTE(review): the @app.route registration was missing; '/health (GET)'
# is the path advertised by the root documentation endpoint.
@app.route('/health', methods=['GET'])
def health_check():
    """Report service liveness and whether all ML artifacts loaded."""
    return jsonify({
        'status': 'healthy',
        'message': 'Phishing detection API is running',
        'models_loaded': rf_model is not None and xgb_model is not None and scaler is not None,
        'numpy_version': np.__version__
    })
# Root endpoint with API documentation.
# NOTE(review): the @app.route registration was missing, leaving the
# documented '/' endpoint unreachable; registered here.
@app.route('/', methods=['GET'])
def root():
    """Return human-readable API documentation with a usage example."""
    return jsonify({
        'message': 'Phishing Detection API',
        'version': '1.0.0',
        'numpy_version': np.__version__,
        'endpoints': {
            'check': '/api/check (POST)',
            'health': '/health (GET)'
        },
        'usage': {
            'check': 'POST {"url": "https://example.com"} to /api/check'
        },
        'example': {
            'request': {
                'url': 'http://paypal.secure.login-update.com'
            },
            'response': {
                'success': True,
                'data': {
                    'url': 'http://paypal.secure.login-update.com',
                    'is_phishing': True,
                    'confidence': 'High',
                    'avg_phishing_probability': 0.9137
                }
            }
        }
    })
# Run the app
if __name__ == '__main__':
    # Bind on all interfaces; the port comes from the PORT environment
    # variable (default 7860 — presumably the Hugging Face Spaces
    # convention, given the Spaces header above; confirm for other hosts).
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))