Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| import requests | |
| import re | |
| from urllib.parse import urlparse | |
| from bs4 import BeautifulSoup | |
| import time | |
| import joblib | |
| # --- import your architecture --- | |
| # Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py) | |
| # and update the import path accordingly. | |
| from model import DeBERTaLSTMClassifier # <-- your class | |
# --------- Config ----------
# Hugging Face Hub repository and file holding the fine-tuned DeBERTa + LSTM weights.
REPO_ID = "khoa-done/phishing-detector"
CKPT_NAME = "deberta_lstm_checkpoint.pt"

# Pretrained backbone; its tokenizer must match the one used during fine-tuning.
MODEL_NAME = "microsoft/deberta-base"

# Class names in logit order (index 0 = benign, index 1 = phishing).
# Update this list if the classification head is retrained with other classes.
LABELS = ["benign", "phishing"]

# Hyperparameters saved inside the checkpoint (for example under "model_args")
# can be forwarded to the model constructor as DeBERTaLSTMClassifier(**model_args).
# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# NOTE(security): depending on the PyTorch version's weights_only default,
# torch.load may unpickle arbitrary objects -- keep REPO_ID pointed at a trusted repo.
checkpoint = torch.load(ckpt_path, map_location=device)
# Constructor hyperparameters saved alongside the weights,
# e.g. {"lstm_hidden": 256, "num_labels": 2, ...}; an empty dict means the
# class's own defaults are used.
model_args = checkpoint.get("model_args", {})
model = DeBERTaLSTMClassifier(**model_args)
state_dict = checkpoint["model_state_dict"]
try:
    model.load_state_dict(state_dict)
except RuntimeError as e:
    if "attention" not in str(e):
        raise
    # Older checkpoints predate the classifier's attention layer. Load every
    # checkpoint tensor whose name and shape match the current model and leave
    # the remaining parameters at their fresh initialisation. Filtering on the
    # substring "attention" instead would also throw away the fine-tuned
    # backbone's self-attention tensors (Hugging Face DeBERTa layers contain an
    # `attention` sub-module), silently degrading the model.
    current_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in current_state
        and getattr(tensor, "shape", None) == getattr(current_state[name], "shape", None)
    }
    model.load_state_dict(compatible, strict=False)
    reinitialized = sorted(set(current_state) - set(compatible))
    print("Loaded model without attention layer, using newly initialized attention weights")
    print(f"Parameters left at initialization ({len(reinitialized)}): {reinitialized}")
model.to(device).eval()
# --------- Load BERT model/tokenizer from Hugging Face Hub ----------
# A second, independent sequence classifier served in its own tab.
BERT_MODEL_PATH = "th1enq/bert_checkpoint"
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_PATH)
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_PATH)
# nn.Module.to() returns the module itself, so rebinding keeps the same object.
bert_model = bert_model.to(device)
bert_model.eval()  # inference mode: disables dropout
# --------- Helper functions ----------
# Compiled once at import time. Accepts http(s) URLs whose host is a domain name
# ending in a 2-63 letter TLD or a punycode "xn--" TLD, the word "localhost",
# or a dotted quad of 1-3 digit numbers, followed by an optional port and
# path/query. (A 2-6 letter TLD limit would reject common phishing gTLDs such
# as ".support" or ".website".)
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,63}|XN--[A-Z0-9-]{1,59})\.?|'  # domain...
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$', re.IGNORECASE)


def is_url(text):
    """Return True if ``text`` looks like an absolute http(s) URL.

    The whole string must match, so surrounding whitespace makes it fail.
    The callers in this file strip their input first. A scheme is required,
    so "www.example.com" is treated as plain text.
    """
    return _URL_PATTERN.match(text) is not None
def fetch_html_content(url, timeout=10, max_bytes=1_000_000):
    """Download (the beginning of) the page at ``url`` for content analysis.

    Args:
        url: Absolute http(s) URL supplied by the user.
        timeout: Seconds used for the connect/read timeouts and also as an
            overall budget for reading the body. ``requests``' own timeout only
            bounds each socket read, so a slow-drip server could otherwise hold
            the worker indefinitely.
        max_bytes: Maximum number of body bytes kept. The classifiers truncate
            their input to a few hundred tokens, so larger pages would only cost
            memory and tokenization time.

    Returns:
        ``(text, status_code)`` on success, or ``(None, message)`` describing
        why the page could not be fetched.
    """
    # NOTE(security): this fetches a user-supplied URL from the server, and
    # is_url() also accepts "localhost" and raw IP addresses -- a potential SSRF
    # vector. Certificate verification is disabled (verify=False), presumably so
    # that phishing pages with broken TLS can still be inspected; confirm.
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    try:
        deadline = time.monotonic() + timeout
        # stream=True lets us stop reading early; the with-block closes the connection.
        with requests.get(url, headers=headers, timeout=timeout, verify=False, stream=True) as response:
            response.raise_for_status()
            body = bytearray()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                body += chunk
                if len(body) >= max_bytes or time.monotonic() > deadline:
                    break
            raw = bytes(body[:max_bytes])
            try:
                text = raw.decode(response.encoding or "utf-8", errors="replace")
            except LookupError:
                # The server advertised a charset Python does not recognise.
                text = raw.decode("utf-8", errors="replace")
            return text, response.status_code
    except requests.exceptions.RequestException as e:
        return None, f"Request error: {str(e)}"
    except Exception as e:
        return None, f"General error: {str(e)}"
def predict_single_text(text, text_type="text", max_length=256):
    """Classify one string with the DeBERTa + LSTM model.

    Args:
        text: The string to classify (a URL, page HTML or free text).
        text_type: Descriptive tag passed by callers ("URL", "HTML", "text");
            not used by the computation.
        max_length: Token budget; longer inputs are truncated.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)``:
        - ``probs``: class probabilities (softmax of the logits) as a list of floats;
        - ``tokens``: the tokenized input, special tokens included;
        - ``has_attention``: whether the model returned attention weights;
        - ``attention_weights``: those weights, or ``None``. Their exact shape is
          defined by DeBERTaLSTMClassifier; presumably ``(1, seq_len)``, TODO confirm.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    attention_weights = None
    with torch.no_grad():
        try:
            result = model(**inputs, return_attention=True)
        except TypeError:
            # Older model classes do not accept the return_attention keyword.
            result = model(**inputs)
        if isinstance(result, tuple) and len(result) == 3:
            # (logits, attention weights, backbone attentions)
            logits, attention_weights, _ = result
        elif isinstance(result, tuple):
            # Unknown tuple layout: assume the logits come first instead of
            # handing the whole tuple to softmax.
            logits = result[0]
        else:
            logits = result

    # squeeze(0) drops the batch dimension of this single-example batch.
    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
    has_attention = attention_weights is not None
    return probs, tokens, has_attention, attention_weights
def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
    """Blend URL-level and page-level class probabilities.

    For every class index ``i`` the result is
    ``url_weight * url_probs[i] + html_weight * html_probs[i]``. With the default
    weights (which sum to 1) the output is again a probability distribution.
    Works for any number of classes, so it stays correct if LABELS grows.

    Raises:
        ValueError: if the two probability vectors differ in length.
    """
    if len(url_probs) != len(html_probs):
        raise ValueError(
            f"Probability vectors differ in length: {len(url_probs)} vs {len(html_probs)}"
        )
    return [url_weight * u + html_weight * h for u, h in zip(url_probs, html_probs)]
# --------- Inference function ----------
# Special tokens left out of the "content token" counts and the plain token list.
_SPECIAL_TOKENS = ('[CLS]', '[SEP]', '[PAD]')
# Tokens never reported in the "Top important tokens" ranking.
_ATTENTION_SKIP_TOKENS = ('[CLS]', '[SEP]', '[PAD]', '<s>', '</s>')
# Shown in the gr.Label output when the input box is blank.
_EMPTY_INPUT_MESSAGE = "Please enter a URL or text."
# Report headings for the (URL + page, URL only, plain text) paths of each tab.
_DEBERTA_TITLES = {
    "combined": "Combined URL + HTML Analysis",
    "url": "URL-only Analysis",
    "text": "Text Analysis",
}
_BERT_TITLES = {
    "combined": "Combined URL + HTML BERT Analysis",
    "url": "URL-only BERT Analysis",
    "text": "BERT Text Analysis",
}


def _to_label_probs(probs):
    """Return a ``{label: probability}`` dict for ``gr.Label``.

    Uses LABELS when its length matches ``probs``; otherwise falls back to
    generic ``class_<i>`` names so an unexpected head size still renders.
    """
    if len(LABELS) == len(probs):
        return {label: float(p) for label, p in zip(LABELS, probs)}
    return {f"class_{i}": float(p) for i, p in enumerate(probs)}


def _analyze_input(text, predict_single, titles):
    """Run ``predict_single`` on ``text`` and, for URLs, on the fetched page too.

    Args:
        text: Raw, non-blank user input.
        predict_single: A function returning
            ``(probs, tokens, has_attention, attention_weights)``.
        titles: Dict with "combined", "url" and "text" headings.

    Returns:
        ``(probs, tokens, attention_tokens, attention_weights, analysis_type,
        fetch_status)``. ``tokens`` are shown and counted in the report;
        ``attention_tokens`` are the tokens ``attention_weights`` actually
        score. ``attention_weights`` is ``None`` when unavailable.
    """
    url = text.strip()
    if not is_url(url):
        probs, tokens, has_attention, attention = predict_single(text, "text")
        return (probs, tokens, tokens, attention if has_attention else None,
                titles["text"], "")

    url_probs, url_tokens, url_has_attention, url_attention = predict_single(url, "URL")
    html_content, status = fetch_html_content(url)
    if not html_content:
        # Fall back to judging the URL alone.
        return (url_probs, url_tokens, url_tokens,
                url_attention if url_has_attention else None,
                titles["url"], f"⚠️ Could not fetch HTML content: {status}")

    html_probs, html_tokens, html_has_attention, html_attention = predict_single(html_content, "HTML")
    probs = combine_predictions(url_probs, html_probs)
    # Display the URL tokens plus a short preview of the page tokens.
    tokens = url_tokens + ["[SEP]"] + html_tokens[:50]
    # Pair the attention scores with the tokens they were computed over, so
    # page-level scores are never zipped against URL tokens.
    if url_has_attention:
        attention_tokens, attention = url_tokens, url_attention
    elif html_has_attention:
        attention_tokens, attention = html_tokens, html_attention
    else:
        attention_tokens, attention = tokens, None
    return (probs, tokens, attention_tokens, attention, titles["combined"],
            f"✅ Successfully fetched HTML content (Status: {status})")


def _render_analysis_html(probs, tokens, attention_tokens, attention_weights, analysis_type, fetch_status):
    """Build the HTML/Markdown report shown under the prediction label.

    Args:
        probs: ``[benign, phishing]`` probabilities.
        tokens: Tokens to count and (without attention) list in the report.
        attention_tokens: Tokens aligned one-to-one with ``attention_weights``.
        attention_weights: Tensor of per-token scores, or ``None`` to render the
            attention-free fallback section.
        analysis_type: Heading for the verdict banner.
        fetch_status: Page-fetch status line; "" omits that section.

    Token text and fetch messages come from user input or fetched web pages, so
    they are HTML-escaped before being embedded in markup.
    """
    from html import escape

    is_phishing = probs[1] > probs[0]
    predicted_class = "phishing" if is_phishing else "benign"
    confidence = max(probs)
    banner_start = '#8b0000' if is_phishing else '#006400'
    banner_end = '#dc143c' if is_phishing else '#228b22'
    accent = '#ff4444' if is_phishing else '#44ff44'
    verdict = ('This appears to be a phishing attempt!' if is_phishing
               else '✅ This appears to be legitimate content.')
    content_tokens = [t for t in tokens if t not in _SPECIAL_TOKENS]

    parts = [f"""
<div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; background: #1e1e1e; padding: 20px; border-radius: 15px;">
<div style="background: linear-gradient(135deg, {banner_start} 0%, {banner_end} 100%); padding: 25px; border-radius: 20px; color: white; text-align: center; margin-bottom: 20px; box-shadow: 0 8px 32px rgba(0,0,0,0.5); border: 2px solid {accent};">
<h2 style="margin: 0 0 10px 0; font-size: 28px; color: white;">🔍 {analysis_type}</h2>
<div style="font-size: 36px; font-weight: bold; margin: 10px 0; color: white;">
{predicted_class.upper()}
</div>
<div style="font-size: 18px; color: #f0f0f0;">
Confidence: {confidence:.1%}
</div>
<div style="margin-top: 15px; font-size: 14px; color: #e0e0e0;">
{verdict}
</div>
</div>
"""]

    if fetch_status:
        parts.append(f"""
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
<strong>Fetch Status:</strong> {escape(fetch_status)}
</div>
""")

    if attention_weights is not None:
        scores = (attention_weights.squeeze(0).tolist() if attention_weights.dim() > 1
                  else attention_weights.tolist())
        ranked = []
        for position, (token, score) in enumerate(zip(attention_tokens, scores)):
            # Skip special tokens, whitespace-only tokens and negligible scores.
            if token not in _ATTENTION_SKIP_TOKENS and len(token.strip()) > 0 and score > 0.005:
                # Strip SentencePiece ('▁') / byte-level BPE ('Ġ') word-start markers.
                clean_token = token.replace('▁', '').replace('Ġ', '').strip()
                if clean_token:
                    ranked.append({'token': clean_token, 'importance': score, 'position': position})
        ranked.sort(key=lambda item: item['importance'], reverse=True)
        high_impact = len([item for item in ranked if item['importance'] > 0.05])

        parts.append(f"""
## Top important tokens:
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
<strong>Analysis Info:</strong> Found {len(ranked)} important tokens out of {len(tokens)} total tokens
</div>
<div style="font-family: Arial, sans-serif;">
""")
        for rank, item in enumerate(ranked[:10], start=1):  # Top 10 tokens
            bar_width = int(item['importance'] * 100)
            parts.append(f"""
<div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {accent};">
<div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">
{rank}.
</div>
<div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">
{escape(item['token'])}
</div>
<div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
<div style="width: {bar_width}%; background-color: {accent}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
</div>
<div style="color: #cccccc; font-size: 12px; font-weight: bold;">
{item['importance']:.1%}
</div>
</div>
""")
        parts.append("</div>\n")
        parts.append(f"""
## Detailed analysis:
<div style="font-family: Arial, sans-serif; background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
<h3 style="margin: 0 0 15px 0; color: white;">Statistical Overview</h3>
<div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{len(content_tokens)}</div>
<div style="font-size: 14px; color: #e0e0e0;">Total tokens</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{high_impact}</div>
<div style="font-size: 14px; color: #e0e0e0;">High impact tokens (>5%)</div>
</div>
</div>
</div>
<div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
<h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3>
<div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
<span style="font-weight: bold; color: #ff4444;">Phishing</span>
<span style="font-weight: bold; color: #44ff44;">Benign</span>
</div>
<div style="width: 100%; background-color: #404040; border-radius: 25px; overflow: hidden; height: 30px; border: 1px solid #666;">
<div style="width: {probs[1]*100:.1f}%; background: linear-gradient(90deg, #ff4444 0%, #ff6666 100%); height: 100%; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 14px;">
{probs[1]:.1%}
</div>
</div>
<div style="margin-top: 10px; text-align: center; color: #cccccc; font-size: 14px;">
Benign: {probs[0]:.1%}
</div>
</div>
""")
    else:
        # Fallback analysis without attention weights.
        token_chips = ''.join(
            f'<span style="background: #404040; color: #64b5f6; padding: 4px 8px; border-radius: 15px; font-size: 12px; border: 1px solid #666;">{escape(token.replace("▁", ""))}</span>'
            for token in content_tokens
        )
        parts.append(f"""
<div style="background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
<h3 style="margin: 0 0 15px 0; color: white;">Basic Analysis</h3>
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[1]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Phishing</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[0]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Benign</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{len(content_tokens)}</div>
<div style="font-size: 14px; color: #e0e0e0;">Tokens</div>
</div>
</div>
</div>
<div style="background: #2d2d2d; padding: 20px; border-radius: 15px; margin: 15px 0; border: 1px solid #555;">
<h3 style="color: #ffffff; margin: 0 0 15px 0;">🔤 Tokens in text:</h3>
<div style="display: flex; flex-wrap: wrap; gap: 8px;">{token_chips}</div>
<div style="margin-top: 15px; padding: 10px; background: #3d2914; border-radius: 8px; border-left: 4px solid #ff9800;">
<strong style="color: #ffcc02;">Debug info:</strong> <span style="color: #e0e0e0;">Found {len(tokens)} total tokens, {len(content_tokens)} content tokens</span>
</div>
</div>
<div style="background: #3d2914; padding: 15px; border-radius: 10px; border-left: 4px solid #ff9800; margin: 15px 0;">
<p style="margin: 0; color: #ffcc02; font-size: 14px;">
<strong>Note:</strong> Detailed attention weights analysis is not available for the current model.
</p>
</div>
""")

    # Close the outer container opened at the top of the report.
    parts.append("</div>")
    return "".join(parts)


def _predict_and_render(text, predict_single, titles):
    """Shared body of both Gradio handlers: analyze ``text`` and build the outputs."""
    probs, tokens, attention_tokens, attention, analysis_type, fetch_status = _analyze_input(
        text, predict_single, titles
    )
    report = _render_analysis_html(probs, tokens, attention_tokens, attention, analysis_type, fetch_status)
    return _to_label_probs(probs), report


def predict_fn(text: str):
    """Gradio handler for the DeBERTa + LSTM tab.

    URLs are scored on their own and, when the page can be fetched, blended
    with the page-content score (see combine_predictions). Other input is
    scored as plain text.

    Returns:
        ``(label_probs, report_html)``: a ``{label: probability}`` dict for
        ``gr.Label`` and the HTML/Markdown report. For blank input the first
        element is a plain message string, which ``gr.Label`` displays as-is
        (a ``{"error": "<text>"}`` dict is not a valid confidence mapping).
    """
    if not text or not text.strip():
        return _EMPTY_INPUT_MESSAGE, ""
    return _predict_and_render(text, predict_single_text, _DEBERTA_TITLES)


# --------- BERT Model Functions ----------
def predict_bert_single_text(text, text_type="text", max_length=512):
    """Classify one string with the BERT sequence classifier.

    Args:
        text: The string to classify (a URL, page HTML or free text).
        text_type: Descriptive tag passed by callers; not used by the computation.
        max_length: Token budget; longer inputs are truncated.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)``. The weights are how
        strongly [CLS] attends to each token in the last layer, averaged over
        heads (a rough importance proxy), as a 1-D tensor, or ``None`` when the
        model returns no attentions.
    """
    inputs = bert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = bert_model(**inputs, output_attentions=True)
    probs = F.softmax(outputs.logits, dim=-1).squeeze(0).tolist()
    tokens = bert_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())

    attentions = getattr(outputs, 'attentions', None)
    if not attentions:
        return probs, tokens, False, None
    # Last layer is (batch, heads, seq, seq) per the transformers API: average
    # the heads, then take example 0's [CLS] row (its attention to every token).
    cls_attention = attentions[-1].mean(dim=1)[0, 0]
    return probs, tokens, True, cls_attention


def predict_bert_interface_fn(text: str):
    """Gradio handler for the BERT tab; same contract as predict_fn."""
    if not text or not text.strip():
        return _EMPTY_INPUT_MESSAGE, ""
    return _predict_and_render(text, predict_bert_single_text, _BERT_TITLES)
# --------- Gradio UI ----------
# Sample inputs shared by both tabs.
_EXAMPLE_INPUTS = (
    "http://rendmoiunserviceeee.com",
    "https://www.google.com",
    "Dear customer, your account has been suspended. Click here to verify your identity immediately.",
    "https://mail-secure-login-verify.example/path?token=suspicious",
    "http://paypaI-security-update.net/login",
    "Your package has been delivered successfully. Thank you for using our service.",
    "https://github.com/user/repo",
)

# Dark-theme overrides. They are passed to the tab container as well, because
# Interfaces rendered inside a TabbedInterface do not apply their own css/theme.
_APP_CSS = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: #1e1e1e !important;
color: #ffffff !important;
}
.dark .gradio-container {
background-color: #1e1e1e !important;
}
/* Dark theme for all components */
.block {
background-color: #2d2d2d !important;
border: 1px solid #444 !important;
}
.gradio-textbox {
background-color: #3d3d3d !important;
color: #ffffff !important;
border: 1px solid #666 !important;
}
.gradio-button {
background-color: #4a4a4a !important;
color: #ffffff !important;
border: 1px solid #666 !important;
}
.gradio-button:hover {
background-color: #5a5a5a !important;
}
"""


def _build_interface(fn, title, intro, importance_detail):
    """Create one analysis tab; both tabs share inputs, outputs, examples and styling.

    Args:
        fn: Handler returning ``(label_probs, report_html)``.
        title: Tab page title.
        intro: First line of the description.
        importance_detail: Text of the "Token Importance" feature bullet.
    """
    description = f"""
{intro}

**Features:**
- **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
- **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Predict phishing/benign probability with visual charts
- **Token Importance**: {importance_detail}
- **Detailed Insights**: Comprehensive analysis of the impact of each token
- **Dark Theme**: Beautiful interface with colorful charts optimized for dark themes

**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results for final prediction (30% URL + 70% content)
"""
    return gr.Interface(
        fn=fn,
        inputs=gr.Textbox(label="URL or text", placeholder="Example: http://suspicious-site.example or paste any text"),
        outputs=[
            gr.Label(label="Prediction result"),
            gr.Markdown(label="Detailed token analysis"),
        ],
        title=title,
        description=description,
        examples=[[example] for example in _EXAMPLE_INPUTS],
        theme=gr.themes.Soft(),
        css=_APP_CSS,
    )


deberta_interface = _build_interface(
    predict_fn,
    "Phishing Detector (DeBERTa + LSTM)",
    "Enter a URL or text for analysis.",
    "Display the most important tokens in classification",
)
bert_interface = _build_interface(
    predict_bert_interface_fn,
    "Phishing Detector (BERT)",
    "Enter a URL or text for analysis using the BERT model.",
    "Display the most important tokens in classification using attention weights",
)
demo = gr.TabbedInterface(
    [deberta_interface, bert_interface],
    ["DeBERTa + LSTM", "BERT"],
    theme=gr.themes.Soft(),
    css=_APP_CSS,
)

if __name__ == "__main__":
    demo.launch()