Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import joblib | |
| import json | |
| import numpy as np | |
| import re | |
| from urllib.parse import urlparse | |
| import os | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub coordinates of the published model.
# MODEL_NAME is informational: the artifact file names used by the loader are
# hard-coded in lower case rather than derived from it.
MODEL_NAME = "XGBoost"
HF_USERNAME = "Devishetty100"
CUSTOM_MODEL_NAME = "NeoGuardianAI"
REPO_ID = "/".join((HF_USERNAME, CUSTOM_MODEL_NAME.lower()))

# Domains that are always reported as safe without consulting the model.
# predict_url matches these against the URL's domain and its parent domains,
# so subdomains of a listed domain are trusted as well.
TRUSTED_DOMAINS = [
    'huggingface.co', 'github.com', 'google.com',
    'microsoft.com', 'apple.com', 'amazon.com',
    'facebook.com', 'twitter.com', 'linkedin.com',
    'youtube.com', 'wikipedia.org',
]
# Artifact file names, shared by the Hub download and the local fallback.
# The model file is lower-case "xgboost_model.joblib" (not MODEL_NAME-derived).
_MODEL_FILENAME = "xgboost_model.joblib"
_SCALER_FILENAME = "scaler.joblib"
_FEATURE_NAMES_FILENAME = "feature_names.json"


def _load_components(model_path, scaler_path, feature_names_path):
    """Load ``(model, scaler, feature_names)`` from the three artifact paths.

    NOTE(review): ``joblib.load`` unpickles arbitrary objects, which can run
    code; only load artifacts from a repository / directory you trust.
    """
    model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)
    with open(feature_names_path, "r", encoding="utf-8") as f:
        feature_names = json.load(f)
    return model, scaler, feature_names


def load_model_files():
    """Load the classifier, its scaler and the ordered list of feature names.

    Tries the Hugging Face Hub repository ``REPO_ID`` first, then falls back
    to files with the same names in the current working directory.

    Returns:
        tuple: ``(model, scaler, feature_names)``.

    Raises:
        RuntimeError: if both the Hub download and the local load fail
            (the local failure is attached as ``__cause__``).
    """
    try:
        print(f"Attempting to download model from Hugging Face Hub: {REPO_ID}")
        # Diagnostic only: listing the repo helps debug filename mismatches,
        # and a failure here must not prevent the download attempt below.
        try:
            from huggingface_hub import list_repo_files
            files = list_repo_files(repo_id=REPO_ID)
            print(f"Files available in the repository: {files}")
        except Exception as list_error:
            print(f"Error listing repository files: {list_error}")

        model_path = hf_hub_download(repo_id=REPO_ID, filename=_MODEL_FILENAME)
        print(f"Downloaded model file to: {model_path}")
        scaler_path = hf_hub_download(repo_id=REPO_ID, filename=_SCALER_FILENAME)
        feature_names_path = hf_hub_download(repo_id=REPO_ID, filename=_FEATURE_NAMES_FILENAME)
        components = _load_components(model_path, scaler_path, feature_names_path)
        print("Successfully downloaded model from Hugging Face Hub.")
        return components
    except Exception as hub_error:
        print(f"Error downloading from Hugging Face Hub: {hub_error}")

    # Hub path failed: try the same file names relative to the CWD.
    try:
        print("Attempting to load model from local files...")
        components = _load_components(_MODEL_FILENAME, _SCALER_FILENAME, _FEATURE_NAMES_FILENAME)
    except Exception as local_error:
        print(f"Error loading from local files: {local_error}")
        raise RuntimeError(
            "Failed to load model from both Hugging Face Hub and local files."
        ) from local_error
    print("Successfully loaded model from local files.")
    return components
# Hosts of known URL-shortening services (duplicates in the original list removed).
_SHORTENING_SERVICES = frozenset([
    'bit.ly', 'goo.gl', 'tinyurl.com', 't.co', 'tr.im', 'is.gd', 'cli.gs',
    'ow.ly', 'yfrog.com', 'migre.me', 'ff.im', 'tiny.cc', 'url4.eu',
    'twit.ac', 'su.pr', 'twurl.nl', 'snipurl.com', 'short.to', 'budurl.com',
    'ping.fm', 'post.ly', 'just.as', 'bkite.com', 'snipr.com', 'fic.kr',
    'loopt.us', 'doiop.com', 'twitthis.com', 'htxt.it', 'ak.im', 'shar.es',
    'kl.am', 'wp.me', 'rubyurl.com', 'om.ly', 'to.ly', 'bit.do', 'lnkd.in',
    'db.tt', 'qr.ae', 'adf.ly', 'bitly.com', 'cur.lv', 'ity.im', 'q.gs',
    'po.st', 'bc.vc', 'u.to', 'j.mp', 'buzurl.com', 'cutt.us', 'u.bb',
    'yourls.org', 'x.co', 'prettylinkpro.com', 'scrnch.me', 'filoops.info',
    'vzturl.com', 'qr.net', '1url.com', 'tweez.me', 'v.gd', 'link.zip.net',
])

# Suffixes checked by ``tld_in_path``; the bare labels are checked by
# ``tld_in_subdomain``.
_TLD_SUFFIXES = ('.com', '.org', '.net', '.edu', '.gov', '.mil', '.int')
_TLD_LABELS = frozenset(tld.lstrip('.') for tld in _TLD_SUFFIXES)

# Features the model knows about that are not derived from the URL string
# here (they would need more complex analysis); they default to 0.
_UNCOMPUTED_FEATURES = (
    'nb_redirection', 'nb_external_redirection', 'length_words_raw',
    'char_repeat', 'shortest_words_raw', 'shortest_word_host',
    'shortest_word_path', 'longest_words_raw', 'longest_word_host',
    'longest_word_path', 'avg_words_raw', 'avg_word_host',
    'avg_word_path', 'phish_hints', 'domain_in_brand',
    'brand_in_subdomain', 'brand_in_path', 'suspecious_tld',
    'statistical_report', 'nb_hyperlinks', 'ratio_intHyperlinks',
    'ratio_extHyperlinks', 'ratio_nullHyperlinks', 'nb_extCSS',
    'ratio_intRedirection', 'ratio_extRedirection', 'ratio_intErrors',
    'ratio_extErrors', 'login_form', 'external_favicon',
    'links_in_tags', 'submit_email', 'ratio_intMedia',
    'ratio_extMedia', 'sfh', 'iframe', 'popup_window',
    'safe_anchor', 'onmouseover', 'right_clic', 'empty_title',
    'domain_in_title', 'domain_with_copyright', 'whois_registered_domain',
    'domain_registration_length', 'domain_age', 'web_traffic',
    'dns_record', 'google_index', 'page_rank',
)


def _parse_url(url):
    """Parse *url*, treating scheme-less input like ``example.com/x`` as a host."""
    parsed = urlparse(url)
    if not parsed.netloc:
        # Without "//", urlparse puts the host into .path (or mistakes
        # "host:port" for a scheme); re-parse as a network-path reference.
        reparsed = urlparse("//" + url)
        if reparsed.netloc:
            parsed = reparsed
    return parsed


def _is_same_or_subdomain(host, domain):
    """Return True if *host* equals *domain* or is a subdomain of it."""
    return host == domain or host.endswith('.' + domain)


def extract_features(url):
    """Extract string-level features from a URL for model prediction.

    Args:
        url: The URL as entered. Scheme-less input such as
            ``"example.com/login"`` is parsed as if it began with ``//``.

    Returns:
        dict mapping feature name to an int (or float for the ``ratio_*``
        features). Every name in ``_UNCOMPUTED_FEATURES`` is present with 0.
    """
    features = {}

    # Basic URL properties
    features['length_url'] = len(url)

    parsed_url = _parse_url(url)
    # Full network location (may include "user:pass@" and ":port"); most
    # host features are computed on this string.
    hostname = parsed_url.netloc
    path = parsed_url.path
    # Bare lower-cased host (no userinfo, port or trailing dot) for the
    # domain-matching features, where substring tests on netloc misfire.
    host = (parsed_url.hostname or '').rstrip('.')
    host_labels = host.split('.')

    # Hostname features
    features['length_hostname'] = len(hostname)
    features['ip'] = 1 if re.match(r'\d+\.\d+\.\d+\.\d+', hostname) else 0

    # Count special characters
    features['nb_dots'] = url.count('.')
    features['nb_hyphens'] = url.count('-')
    features['nb_at'] = url.count('@')
    features['nb_qm'] = url.count('?')
    features['nb_and'] = url.count('&')
    features['nb_or'] = url.count('|')
    features['nb_eq'] = url.count('=')
    features['nb_underscore'] = url.count('_')
    features['nb_tilde'] = url.count('~')
    features['nb_percent'] = url.count('%')
    features['nb_slash'] = url.count('/')
    features['nb_star'] = url.count('*')
    features['nb_colon'] = url.count(':')
    features['nb_comma'] = url.count(',')
    features['nb_semicolumn'] = url.count(';')
    features['nb_dollar'] = url.count('$')
    features['nb_space'] = url.count(' ')

    # Other URL features
    features['nb_www'] = 1 if 'www' in hostname else 0
    features['nb_com'] = 1 if '.com' in hostname else 0
    features['nb_dslash'] = url.count('//')
    features['http_in_path'] = 1 if 'http' in path else 0
    features['https_token'] = 1 if 'https' in url and 'http://' not in url else 0

    # Ratio features
    digits_count = sum(c.isdigit() for c in url)
    features['ratio_digits_url'] = digits_count / len(url) if len(url) > 0 else 0
    features['ratio_digits_host'] = sum(c.isdigit() for c in hostname) / len(hostname) if len(hostname) > 0 else 0

    # Punycode
    features['punycode'] = 1 if 'xn--' in hostname else 0

    # Port: look only at the part after any "user:pass@" so a password is not
    # taken for a port, and require an all-digit suffix after the last ':'
    # (for an IPv6 literal like "[::1]" that suffix is "1]", so it is ignored).
    _, colon, port_text = hostname.rpartition('@')[2].rpartition(':')
    features['port'] = 1 if colon and port_text.isdigit() else 0

    # TLD features
    features['tld_in_path'] = 1 if any(tld in path for tld in _TLD_SUFFIXES) else 0
    # A TLD label among the subdomain labels (all but the last two), e.g.
    # "paypal.com.evil.net" -> subdomain labels ["paypal", "com"].
    features['tld_in_subdomain'] = 1 if any(label in _TLD_LABELS for label in host_labels[:-2]) else 0

    # Subdomain features
    features['abnormal_subdomain'] = 1 if hostname.count('.') > 2 else 0
    features['nb_subdomains'] = hostname.count('.')

    # Other suspicious features
    features['prefix_suffix'] = 1 if '-' in hostname else 0
    features['random_domain'] = 1 if len(hostname) > 12 and sum(c.isdigit() for c in hostname) > 4 else 0

    # Shortening service: exact host or parent-domain match. A substring test
    # would flag e.g. "microsoft.com" (contains "t.co") or "netflix.com" ("x.co").
    features['shortening_service'] = 1 if any(
        _is_same_or_subdomain(host, service) for service in _SHORTENING_SERVICES
    ) else 0

    # Path features
    features['path_extension'] = 1 if '.' in path.split('/')[-1] else 0

    # Fill in the features this extractor does not compute.
    for name in _UNCOMPUTED_FEATURES:
        features.setdefault(name, 0)

    return features
# Load model and components at import time.
try:
    model, scaler, feature_names = load_model_files()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # Do NOT substitute a fake classifier: a stand-in that always predicts
    # "legitimate" would make this phishing checker report every URL as safe.
    # Instead install placeholders whose prediction raises, so callers
    # (predict_url) surface an explicit error. Trusted-domain checks, which
    # never touch the model, keep working.
    print("Model unavailable; URL checks will report an error until it is loaded.")

    class _UnavailableModel:
        """Placeholder classifier that refuses to predict."""

        def predict(self, X):
            raise RuntimeError("Model is not loaded; cannot classify URL.")

        def predict_proba(self, X):
            raise RuntimeError("Model is not loaded; cannot classify URL.")

    class _IdentityScaler:
        """Pass-through stand-in for the fitted scaler."""

        def transform(self, X):
            return X

    model = _UnavailableModel()
    scaler = _IdentityScaler()
    feature_names = ['length_url', 'length_hostname']
def predict_url(url):
    """Predict if a URL is phishing or legitimate.

    Args:
        url: URL text from the UI; surrounding whitespace is ignored and a
            missing scheme (``"example.com"``) is tolerated.

    Returns:
        tuple ``(prediction_text, confidence, status)`` for the three Gradio
        output boxes. URLs whose host is, or is a subdomain of, an entry in
        ``TRUSTED_DOMAINS`` return ``("Legitimate (Trusted Domain)", 1.0,
        "✅ SAFE")`` without consulting the model. Any exception while
        parsing or predicting is returned as ``("Error: ...", 0.0, "Error")``.
    """
    if not url or not url.strip():
        return "Please enter a URL", 0.0, "N/A"
    url = url.strip()

    try:
        # Parse the host for the trusted-domain check. Without "//", urlparse
        # leaves the netloc empty, so re-parse scheme-less input as "//...".
        parsed_url = urlparse(url)
        if not parsed_url.netloc:
            reparsed = urlparse("//" + url)
            if reparsed.netloc:
                parsed_url = reparsed
        # .hostname is lower-cased and excludes userinfo and port, so
        # "https://GOOGLE.com:443/" is compared as "google.com".
        domain = (parsed_url.hostname or "").rstrip(".")

        # Trusted if the domain or any parent domain (never a bare TLD) is listed.
        domain_parts = domain.split(".")
        is_trusted = any(
            ".".join(domain_parts[i:]) in TRUSTED_DOMAINS
            for i in range(len(domain_parts) - 1)
        )
        if is_trusted:
            return "Legitimate (Trusted Domain)", 1.0, "✅ SAFE"

        # Build the feature vector in the order the model was trained on;
        # names the extractor does not produce default to 0.
        url_features = extract_features(url)
        features_array = [url_features.get(name, 0) for name in feature_names]

        scaled_features = scaler.transform([features_array])
        prediction = model.predict(scaled_features)[0]
        probability = model.predict_proba(scaled_features)[0][1]  # P(phishing)

        if prediction == 1:
            return "Phishing", float(probability), "⚠️ UNSAFE"
        return "Legitimate", float(1 - probability), "✅ SAFE"
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return error_msg, 0.0, "Error"
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the URL checker.

    Returns:
        The (not yet launched) ``gr.Blocks`` app. Clicking "Check URL" or
        pressing Enter in the URL box calls ``predict_url`` and fills the
        Prediction, Confidence and Status boxes.
    """
    with gr.Blocks(title="NeoGuardianAI - URL Phishing Detection", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # NeoGuardianAI - URL Phishing Detection

            This app uses a machine learning model to detect if a URL is legitimate or phishing.
            Enter a URL below to check if it's safe or potentially malicious.
            """
        )

        with gr.Row():
            url_input = gr.Textbox(label="Enter URL", placeholder="https://example.com")
            submit_btn = gr.Button("Check URL", variant="primary")

        with gr.Row():
            status_output = gr.Textbox(label="Status")
            prediction_output = gr.Textbox(label="Prediction")
            confidence_output = gr.Textbox(label="Confidence")

        # predict_url returns (prediction_text, confidence, status), so the
        # outputs are listed in that order even though Status is shown first.
        outputs = [prediction_output, confidence_output, status_output]
        submit_btn.click(fn=predict_url, inputs=url_input, outputs=outputs)
        # Also run the check when the user presses Enter in the URL box.
        url_input.submit(fn=predict_url, inputs=url_input, outputs=outputs)

        gr.Markdown(
            """
            ## How it works

            This model was trained on the [pirocheto/phishing-url](https://huggingface.co/datasets/pirocheto/phishing-url) dataset from Hugging Face.
            The model extracts various features from the URL and uses a machine learning algorithm to classify it as legitimate or phishing.

            **Note**: While this model is highly accurate, it's not perfect. Always exercise caution when visiting unfamiliar websites.

            ## API Usage

            You can also use this model via the Hugging Face Inference API:

            ```python
            import requests

            API_URL = "https://api-inference.huggingface.co/models/Devishetty100/neoguardianai"
            headers = {"Authorization": "Bearer YOUR_API_TOKEN"}

            def query(url):
                payload = {"inputs": url}
                response = requests.post(API_URL, headers=headers, json=payload)
                return response.json()

            # Example
            result = query("https://example.com")
            print(result)
            ```
            """
        )

    return demo
# Script entry point: build the UI, keep it bound to the module-level name
# ``demo``, and start the Gradio server with default launch settings.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()