"""Gradio demo that scores a piece of text as phishing or legitimate.

At import time the script downloads the NLTK resources it needs, loads the
``distilbert-base-uncased`` tokenizer and a two-label DistilBERT classifier,
and overwrites the classifier weights with the state dict stored in
``best_model.pth``. Running the file as a script launches the Gradio UI.
"""

import re

import gradio as gr
import nltk
import torch
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Download NLTK resources (optional if already available)
nltk.download('punkt_tab')
nltk.download('stopwords')
nltk.download('wordnet')

# Preprocessing setup
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

# Patterns used by preprocess_text, compiled once at import time.
_NON_ALPHA_RE = re.compile(r'[^A-Za-z\s]')
_URL_RE = re.compile(r'http\S+|www\S+|https\S+')
_WHITESPACE_RE = re.compile(r'\s+')


def preprocess_text(text: str) -> str:
    """Normalize raw text into a space-joined string of lemmatized tokens.

    Steps, in order: drop every character that is not an ASCII letter or
    whitespace, drop tokens starting with ``http``/``www``, collapse runs of
    whitespace, lowercase, tokenize, remove English stop words, lemmatize.

    Args:
        text: The raw input string.

    Returns:
        The cleaned tokens joined by single spaces (may be empty).
    """
    # NOTE(review): punctuation is stripped *before* URL removal, so by the
    # time the URL pattern runs a link such as 'http://a.com/x' has already
    # become 'httpacomx'. The order is kept as-is because inference-time
    # preprocessing presumably has to match what the model was trained on --
    # confirm against the training pipeline before reordering.
    text = _NON_ALPHA_RE.sub('', text)
    text = _URL_RE.sub('', text)
    text = _WHITESPACE_RE.sub(' ', text).strip()
    text = text.lower()

    tokens = word_tokenize(text)
    # Filter stop words first, then lemmatize the survivors.
    tokens = [
        lemmatizer.lemmatize(word)
        for word in tokens
        if word not in stop_words
    ]
    return ' '.join(tokens)


# Load tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

# Load trained phishing detection model
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code if the checkpoint comes from an untrusted source. Only ship a
# best_model.pth you produced yourself; on PyTorch versions that support it,
# consider passing weights_only=True.
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device("cpu"))
)
model.eval()

# Label mapping (index of the logit -> class name shown in the UI).
# NOTE(review): assumes the checkpoint was trained with 0 = phishing and
# 1 = legitimate -- confirm against the training labels.
idx2label = {0: "phishing", 1: "legitimate"}


# Prediction function
def predict(text: str) -> dict:
    """Return the class probabilities for *text*.

    Args:
        text: The message or account description to classify. ``None`` is
            treated as an empty string.

    Returns:
        A dict mapping each label in ``idx2label`` to its softmax
        probability, rounded to four decimal places.
    """
    clean_text = preprocess_text(text or "")
    inputs = tokenizer(
        clean_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        outputs = model(**inputs)
        # Single input, so row 0 holds the only set of probabilities.
        # tolist() yields Python floats, so the rounding below happens in
        # double precision instead of on float32 values.
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0].tolist()
    return {label: round(probs[idx], 4) for idx, label in idx2label.items()}


# Gradio UI
interface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        lines=4,
        placeholder="Enter a suspicious message or account description...",
    ),
    outputs=gr.Label(num_top_classes=2),
    title="🛡️ Phishing Account Detector",
    description=(
        "Detects whether an account or message is likely phishing or "
        "legitimate using a custom DistilBERT model."
    ),
)

if __name__ == "__main__":
    interface.launch()