File size: 2,245 Bytes
aa971e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import gradio as gr
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
def _ensure_nltk_resource(path, package):
    """Download NLTK *package* only if *path* is not found in the local data dirs.

    ``nltk.download`` queries the remote NLTK index even when the package is
    already installed, so a local lookup is tried first. If the lookup fails
    for any reason, this falls back to a plain download.
    """
    try:
        nltk.data.find(path)
    except LookupError:
        nltk.download(package)


# Download NLTK resources only when they are not already available locally.
_ensure_nltk_resource('tokenizers/punkt_tab', 'punkt_tab')  # needed by word_tokenize
_ensure_nltk_resource('corpora/stopwords', 'stopwords')
_ensure_nltk_resource('corpora/wordnet', 'wordnet')  # needed by WordNetLemmatizer
# Preprocessing setup
stop_words = set(stopwords.words('english'))  # set for O(1) membership tests
lemmatizer = WordNetLemmatizer()
def preprocess_text(text):
    """Normalize raw text before it is handed to the DistilBERT tokenizer.

    Steps, in order:
      1. Strip every character that is not an ASCII letter or whitespace.
      2. Drop any run starting with ``http``/``www`` up to the next whitespace.
      3. Collapse whitespace runs, trim, and lowercase.
      4. Tokenize with NLTK.
      5. Remove English stopwords and lemmatize the remaining tokens.

    NOTE(review): step 2 runs *after* punctuation is stripped, so it matches
    the mangled form (``http://example.com`` -> ``httpexamplecom``). Because
    lowercasing happens later, it also misses upper-case ``HTTP``/``WWW``
    prefixes. This is kept as-is on the assumption that training used
    identical preprocessing; confirm before changing it.

    Args:
        text: Raw input string.

    Returns:
        The surviving tokens joined by single spaces.
    """
    # Keep only ASCII letters and whitespace.
    cleaned = re.sub(r'[^A-Za-z\s]', '', text)
    # Drop URL-like runs (see NOTE above about ordering).
    cleaned = re.sub(r'http\S+|www\S+|https\S+', '', cleaned)
    # Collapse whitespace runs, trim the ends, then lowercase.
    cleaned = re.sub(r'\s+', ' ', cleaned).strip().lower()
    kept = (
        lemmatizer.lemmatize(token)
        for token in word_tokenize(cleaned)
        if token not in stop_words
    )
    return ' '.join(kept)
# Load tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Base DistilBERT with a 2-way classification head. load_state_dict (strict by
# default) then replaces all of its weights with the fine-tuned checkpoint.
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
# Load the trained phishing-detection weights onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code on load. This is the default
# from PyTorch 2.6 onward; passing it explicitly keeps older versions safe too.
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device("cpu"), weights_only=True)
)
model.eval()  # inference mode (e.g. disables dropout)
# Label mapping: output-logit index -> human-readable label.
# NOTE(review): must match the label encoding used during training; confirm.
idx2label = {0: "phishing", 1: "legitimate"}
# Prediction function
def predict(text):
    """Classify *text* as phishing or legitimate.

    The text is cleaned with ``preprocess_text``, tokenized (truncated to 128
    tokens), and scored by the DistilBERT classifier. The logits are turned
    into probabilities with a softmax.

    Args:
        text: Raw user input from the Gradio textbox. ``None`` is treated as
            an empty string.

    Returns:
        A dict mapping each label in ``idx2label`` to its probability as a
        Python float rounded to 4 decimal places. This dict of confidences is
        what ``gr.Label`` displays.
    """
    # Gradio can deliver None (e.g. an API call with no value); treat it as
    # empty text instead of letting the regexes in preprocess_text raise.
    clean_text = preprocess_text(text if text is not None else "")
    inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    # logits are (batch, num_labels); the batch here holds a single text.
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0].numpy()
    # Convert to a Python float *before* rounding. Rounding a numpy float32
    # first yields values like 0.12349999696 once converted to float.
    # Iterating idx2label keeps the output keys in sync with the label map.
    return {label: round(float(probs[idx]), 4) for idx, label in idx2label.items()}
# Gradio UI: one multi-line text input, a two-class label output.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        lines=4,
        placeholder="Enter a suspicious message or account description...",
    ),
    outputs=gr.Label(num_top_classes=2),
    title="🛡️ Phishing Account Detector",
    description=(
        "Detects whether an account or message is likely phishing or "
        "legitimate using a custom DistilBERT model."
    ),
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    interface.launch()
|