# cs / app.py
# Provenance (Hugging Face Space page residue): Shivangguptasih's picture
# Update app.py
# commit 92ce2fe verified
# app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
import pandas as pd
import numpy as np
import re
import whois
import dns.resolver
import tldextract
import requests
from urllib.parse import urlparse, parse_qs
from datetime import datetime
import time
import warnings
import joblib
import os
from functools import lru_cache
import threading
warnings.filterwarnings('ignore')
# Initialize Flask app
app = Flask(__name__)
CORS(app) # Enable CORS for all routes
# Extended whitelist of known legitimate domains.
# Stored as a frozenset: is_domain_whitelisted() probes this once per
# parent-domain suffix, so membership must be O(1); the set also collapses
# the duplicate 'aws.amazon.com' entry the original list carried twice.
LEGITIMATE_DOMAINS = frozenset([
    # Major tech companies
    'google.com', 'youtube.com', 'facebook.com', 'instagram.com', 'twitter.com',
    'linkedin.com', 'github.com', 'microsoft.com', 'apple.com', 'amazon.com',
    'wikipedia.org', 'stackoverflow.com', 'reddit.com', 'netflix.com', 'paypal.com',
    # Google services
    'accounts.google.com', 'mail.google.com', 'drive.google.com', 'docs.google.com',
    # Microsoft services
    'office.com', 'outlook.com', 'live.com', 'hotmail.com', 'onedrive.com',
    # Apple services
    'icloud.com', 'appleid.apple.com', 'developer.apple.com', 'support.apple.com',
    # Amazon services
    'aws.amazon.com', 'primevideo.com', 'music.amazon.com', 'photos.amazon.com',
    # Developer platforms
    'npmjs.com', 'pypi.org', 'rubygems.org', 'crates.io', 'packagist.org',
    'maven.org', 'nuget.org', 'cpan.org', 'docker.com', 'kubernetes.io',
    'tensorflow.org', 'pytorch.org', 'huggingface.co', 'kaggle.com',
    'colab.research.google.com', 'jupyter.org', 'anaconda.com',
    # Social media
    'pinterest.com', 'snapchat.com', 'tiktok.com', 'whatsapp.com', 'messenger.com',
    'telegram.org', 'discord.com', 'twitch.tv', 'medium.com', 'quora.com',
    # E-commerce
    'ebay.com', 'etsy.com', 'shopify.com', 'aliexpress.com', 'alibaba.com',
    'target.com', 'walmart.com', 'bestbuy.com', 'homedepot.com',
    # News and media
    'cnn.com', 'bbc.com', 'nytimes.com', 'theguardian.com', 'reuters.com',
    'bloomberg.com', 'wsj.com', 'forbes.com', 'techcrunch.com', 'wired.com',
    # Education
    'coursera.org', 'udemy.com', 'edx.org', 'khanacademy.org', 'mit.edu',
    'stanford.edu', 'harvard.edu', 'berkeley.edu', 'cmu.edu',
    # Cloud providers
    'azure.microsoft.com', 'cloud.google.com', 'oracle.com',
    'ibm.com', 'salesforce.com', 'sap.com', 'adobe.com', 'service-now.com',
    # Common legitimate TLDs
    'githubusercontent.com', 'github.io', 'wordpress.com', 'wordpress.org',
    'bitbucket.org', 'gitlab.com', 'sourceforge.net', 'fandom.com'
])
# Pre-load feature columns to avoid downloading every time
# (populated as a side effect of initialize(); stays None if startup fails)
FEATURE_COLUMNS = None
# Thread-safe caches for DNS / WHOIS lookup results, shared across
# Flask request threads and guarded by cache_lock.
dns_cache = {}
whois_cache = {}
cache_lock = threading.Lock()
# DNS resolver with timeout and custom nameservers.
# A 1-second timeout/lifetime keeps slow lookups from stalling API requests;
# Google public DNS is used instead of the host's default resolvers.
resolver = dns.resolver.Resolver()
resolver.timeout = 1.0
resolver.lifetime = 1.0
resolver.nameservers = ['8.8.8.8', '8.8.4.4']
# Load the models and feature columns once at startup
def initialize():
    """Load both classifiers, the scaler, and the training feature list.

    Side effects:
        Sets the module-global FEATURE_COLUMNS to the training columns
        (every column of the reference dataset except the 'phishing' label).

    Returns:
        (rf_model, xgb_model, scaler) on success, or (None, None, None)
        if any artifact fails to load or the dataset header is unreachable.
    """
    global FEATURE_COLUMNS
    try:
        # Load models
        print("Loading models...")
        rf_model = joblib.load('phishing_detector_rf_grega.pkl')
        xgb_model = joblib.load('phishing_detector_xgb_grega.pkl')
        scaler = joblib.load('feature_scaler_grega.pkl')
        print("Models loaded successfully!")
        # Load feature columns.  nrows=0 fetches only the CSV header --
        # we need just the column names, not the whole dataset.
        print("Loading feature columns...")
        url = "https://raw.githubusercontent.com/GregaVrbancic/Phishing-Dataset/master/dataset_small.csv"
        header = pd.read_csv(url, nrows=0)
        FEATURE_COLUMNS = header.drop('phishing', axis=1).columns.tolist()
        print(f"Feature columns loaded: {len(FEATURE_COLUMNS)} features")
        return rf_model, xgb_model, scaler
    except Exception as e:
        # Best-effort startup: report and let the API surface the failure.
        print(f"Initialization error: {e}")
        return None, None, None
# Initialize models at startup
print("Initializing models...")
rf_model, xgb_model, scaler = initialize()
# Explicit `is None` check: a bare truthiness test on an arbitrary model
# object is fragile (estimators may define __len__ or similar and raise).
if rf_model is None:
    print("Failed to initialize models!")
    # We'll continue and return error messages in the API
# Fast DNS lookup with caching
@lru_cache(maxsize=5000)
def fast_dns_lookup(domain, record_type='A'):
    """Resolve `domain` and return a numeric summary of the answer.

    Memoized via lru_cache (the original also kept a manual dict cache,
    but lru_cache intercepts every repeat call first, so the dict was
    dead weight and has been dropped).

    Returns:
        'A'    -> TTL of the answer rrset (consumed as the ttl_hostname feature)
        'TXT'  -> 1 if an SPF record ('v=spf1') is present, else 0
        other  -> number of records in the answer (NS / MX counts)
        0 on any resolution failure.
    """
    try:
        answers = resolver.resolve(domain, record_type)
        if record_type == 'A':
            return answers.rrset.ttl if answers else 0
        if record_type == 'TXT':
            return 1 if any('v=spf1' in str(r) for r in answers) else 0
        return len(answers)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        return 0
# Fast WHOIS lookup with caching
@lru_cache(maxsize=5000)
def fast_whois_lookup(domain):
    """Return (domain_age_days, days_to_expiration) for `domain`.

    Memoized via lru_cache (the redundant manual dict cache from the
    original was removed; lru_cache already handles repeat calls).
    Returns (0, 0) on any WHOIS failure so callers always get numbers.
    """
    try:
        # NOTE(review): the `timeout` kwarg is not accepted by every
        # python-whois version -- a TypeError here falls through to (0, 0);
        # confirm against the installed package.
        w = whois.whois(domain, timeout=2)
        creation_date = w.creation_date
        expiration_date = w.expiration_date
        # python-whois may return a list of dates; use the first one.
        if isinstance(creation_date, list):
            creation_date = creation_date[0]
        if isinstance(expiration_date, list):
            expiration_date = expiration_date[0]
        # Some registries return timezone-aware datetimes; subtracting a
        # naive datetime.now() from those raises TypeError, which the old
        # bare `except:` silently converted into (0, 0).  Strip tzinfo so
        # the arithmetic always succeeds.
        if creation_date is not None and creation_date.tzinfo is not None:
            creation_date = creation_date.replace(tzinfo=None)
        if expiration_date is not None and expiration_date.tzinfo is not None:
            expiration_date = expiration_date.replace(tzinfo=None)
        domain_age = (datetime.now() - creation_date).days if creation_date else 0
        days_to_expiration = (expiration_date - datetime.now()).days if expiration_date else 0
        return (domain_age, days_to_expiration)
    except Exception:
        return (0, 0)
# Check if domain is whitelisted (optimized)
def is_domain_whitelisted(domain):
    """Check if domain or any of its parent domains are in the whitelist.

    Hostnames are case-insensitive, so the comparison is done on the
    lowercased name -- the original missed mixed-case input like
    'Google.COM'.
    """
    # Remove port if present, then normalize case
    domain = domain.split(':')[0].lower()
    # Direct match (most common case)
    if domain in LEGITIMATE_DOMAINS:
        return True
    # Walk parent suffixes: a.b.google.com -> b.google.com -> google.com
    parts = domain.split('.')
    return any('.'.join(parts[i:]) in LEGITIMATE_DOMAINS
               for i in range(1, len(parts)))
# Optimized feature extraction
# Characters counted in every URL section, mapped to their feature-name stem
# (qty_<stem>_<section>).  One shared table replaces the five copy-pasted
# per-section dicts of the original.
_SPECIAL_CHAR_NAMES = {
    '.': 'dot', '-': 'hyphen', '_': 'underline', '/': 'slash',
    '?': 'questionmark', '=': 'equal', '@': 'at', '&': 'and',
    '!': 'exclamation', ' ': 'space', '~': 'tilde', ',': 'comma',
    '+': 'plus', '*': 'asterisk', '#': 'hashtag', '$': 'dollar',
    '%': 'percent'
}

def _count_special_chars(text, section):
    """Return {'qty_<stem>_<section>': text.count(char)} for every tracked char."""
    return {f'qty_{stem}_{section}': text.count(char)
            for char, stem in _SPECIAL_CHAR_NAMES.items()}

def extract_features_fast(url):
    """Extract the model's lexical + network features from `url`.

    Returns a dict of feature_name -> numeric value.  Whitelisted domains
    get preset "known good" network values so no DNS/WHOIS lookups run.
    """
    features = {}
    try:
        parsed = urlparse(url)
    except Exception:
        # Unparseable URL: emit an all-zero vector.  Guard against
        # FEATURE_COLUMNS still being None if startup initialization failed.
        return {col: 0 for col in (FEATURE_COLUMNS or [])}
    ext = tldextract.extract(url)
    domain = parsed.netloc
    path = parsed.path
    query = parsed.query

    # --- Lexical features, one section at a time ---
    features.update(_count_special_chars(url, 'url'))
    features['qty_tld_url'] = len(ext.suffix) if ext.suffix else 0
    features['length_url'] = len(url)

    features.update(_count_special_chars(domain, 'domain'))
    vowels = 'aeiouAEIOU'
    features['qty_vowels_domain'] = sum(1 for c in domain if c in vowels)
    features['domain_length'] = len(domain)
    features['domain_in_ip'] = 1 if re.match(r'^\d+\.\d+\.\d+\.\d+$', domain) else 0
    features['server_client_domain'] = 1 if any(keyword in domain.lower() for keyword in ['server', 'client']) else 0

    features.update(_count_special_chars(path, 'directory'))
    features['directory_length'] = len(path)

    # Last path segment is treated as the "file" part
    file_name = path.split('/')[-1] if '/' in path else ''
    features.update(_count_special_chars(file_name, 'file'))
    features['file_length'] = len(file_name)

    features.update(_count_special_chars(query, 'params'))
    features['params_length'] = len(query)
    features['tld_present_params'] = 1 if ext.suffix and ext.suffix in query else 0
    features['qty_params'] = len(parse_qs(query))
    features['email_in_url'] = 1 if re.search(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', url) else 0

    # --- Network features ---
    if is_domain_whitelisted(domain):
        # Preset "known good" values; skips all remote lookups.
        features.update({
            'time_response': 0.1,
            'domain_spf': 1,
            'asn_ip': 15169,
            'time_domain_activation': 3650,  # ~10 years
            'time_domain_expiration': 365,
            'qty_ip_resolved': 1,
            'qty_nameservers': 4,
            'qty_mx_servers': 1,
            'ttl_hostname': 300,
            'tls_ssl_certificate': 1,
            'qty_redirects': 0,
            'url_google_index': 1,
            'domain_google_index': 1,
            'url_shortened': 0,
        })
    else:
        # Cached DNS / WHOIS lookups (the unused start_time stopwatch from
        # the original was removed).
        features['time_response'] = 0.1  # Default value
        features['domain_spf'] = fast_dns_lookup(domain, 'TXT')
        features['asn_ip'] = 0
        domain_age, days_to_expiration = fast_whois_lookup(domain)
        features['time_domain_activation'] = domain_age
        features['time_domain_expiration'] = days_to_expiration
        # NOTE: both of these receive the A-record TTL -- fast_dns_lookup
        # returns the TTL (not a count) for record_type='A'.  Preserved
        # as-is to match the behavior the deployed models were trained on.
        features['qty_ip_resolved'] = fast_dns_lookup(domain, 'A')
        features['qty_nameservers'] = fast_dns_lookup(domain, 'NS')
        features['qty_mx_servers'] = fast_dns_lookup(domain, 'MX')
        features['ttl_hostname'] = fast_dns_lookup(domain, 'A')
        features['tls_ssl_certificate'] = 0
        features['qty_redirects'] = 0
        features['url_google_index'] = 0
        features['domain_google_index'] = 0
        shorteners = ['bit.ly', 'tinyurl.com', 'short.link', 'goo.gl', 't.co']
        features['url_shortened'] = 1 if any(s in domain for s in shorteners) else 0
    return features
# Fast prediction function
def predict_phishing_fast(url):
    """Classify `url` and return a JSON-serializable result dict.

    Whitelisted domains short-circuit with a fixed "safe" verdict; all
    other URLs go through feature extraction and both classifiers.

    Raises:
        RuntimeError: if the models failed to load at startup (gives the
        API a clear 500 message instead of an opaque AttributeError).
    """
    start_time = time.time()
    # Ensure URL has a protocol
    if not url.startswith(('http://', 'https://')):
        url = 'http://' + url
    # Extract domain for whitelist check
    parsed = urlparse(url)
    domain = parsed.netloc
    # Check if domain is whitelisted (immediate return)
    if is_domain_whitelisted(domain):
        return {
            'url': url,
            'rf_phishing_probability': 0.01,
            'xgb_phishing_probability': 0.01,
            'avg_phishing_probability': 0.01,
            'is_phishing': False,
            'confidence': 'High',
            'whitelisted': True,
            'processing_time_ms': round((time.time() - start_time) * 1000, 2)
        }
    # Fail fast with a clear message when startup initialization failed
    # (the /api/check handler turns this into a 500 response).
    if rf_model is None or xgb_model is None or scaler is None:
        raise RuntimeError('Models are not loaded; check startup logs')
    # Extract features and align columns with the training data
    features = extract_features_fast(url)
    features_df = pd.DataFrame([features])
    for col in set(FEATURE_COLUMNS) - set(features_df.columns):
        features_df[col] = 0
    features_df = features_df[FEATURE_COLUMNS]
    # Scale and predict.  float() casts matter: numpy scalars
    # (np.float64 / np.bool_) are not JSON-serializable by flask.jsonify.
    features_scaled = scaler.transform(features_df)
    rf_prob = float(rf_model.predict_proba(features_scaled)[0][1])
    xgb_prob = float(xgb_model.predict_proba(features_scaled)[0][1])
    avg_prob = (rf_prob + xgb_prob) / 2
    # Confidence: high near either extreme, low in the ambiguous middle
    if avg_prob > 0.8 or avg_prob < 0.2:
        confidence = 'High'
    elif avg_prob > 0.65 or avg_prob < 0.35:
        confidence = 'Medium'
    else:
        confidence = 'Low'
    return {
        'url': url,
        'rf_phishing_probability': round(rf_prob, 4),
        'xgb_phishing_probability': round(xgb_prob, 4),
        'avg_phishing_probability': round(avg_prob, 4),
        'is_phishing': avg_prob > 0.8,
        'confidence': confidence,
        'whitelisted': False,
        'processing_time_ms': round((time.time() - start_time) * 1000, 2)
    }
# API endpoint for checking URLs
@app.route('/api/check', methods=['POST', 'OPTIONS'])
def check_url():
    """POST {"url": ...} -> phishing verdict; OPTIONS -> CORS preflight."""
    # CORS preflight gets an immediate empty 200
    if request.method == 'OPTIONS':
        return '', 200
    try:
        payload = request.get_json()
        # Reject requests with no JSON body or without a 'url' key
        if not payload or 'url' not in payload:
            return jsonify({
                'error': 'Missing URL parameter',
                'message': 'Please provide a URL in the request body'
            }), 400
        target = payload['url']
        # The URL must be a non-blank string
        if not (isinstance(target, str) and target.strip()):
            return jsonify({
                'error': 'Invalid URL',
                'message': 'URL must be a non-empty string'
            }), 400
        # Run the classifier and wrap the result
        verdict = predict_phishing_fast(target)
        return jsonify({'success': True, 'data': verdict})
    except Exception as e:
        return jsonify({
            'success': False,
            'error': 'Internal server error',
            'message': str(e)
        }), 500
# Health check endpoint
@app.route('/health', methods=['GET'])
def health_check():
    """Report service liveness and whether all model artifacts loaded."""
    all_models_ready = all(obj is not None for obj in (rf_model, xgb_model, scaler))
    status_payload = {
        'status': 'healthy',
        'message': 'Phishing detection API is running',
        'models_loaded': all_models_ready,
        'numpy_version': np.__version__,
    }
    return jsonify(status_payload)
# Root endpoint with API documentation
@app.route('/', methods=['GET'])
def root():
    """Serve a small self-describing JSON document for the API."""
    api_doc = {
        'message': 'Phishing Detection API',
        'version': '1.0.0',
        'numpy_version': np.__version__,
        'endpoints': {
            'check': '/api/check (POST)',
            'health': '/health (GET)',
        },
        'usage': {
            'check': 'POST {"url": "https://example.com"} to /api/check',
        },
        'example': {
            'request': {'url': 'http://paypal.secure.login-update.com'},
            'response': {
                'success': True,
                'data': {
                    'url': 'http://paypal.secure.login-update.com',
                    'is_phishing': True,
                    'confidence': 'High',
                    'avg_phishing_probability': 0.9137,
                },
            },
        },
    }
    return jsonify(api_doc)
# Run the app
if __name__ == '__main__':
    # Bind to all interfaces; port comes from the environment with a
    # default of 7860 (the Hugging Face Spaces convention).
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)