ma4389 commited on
Commit
aa971e4
·
verified ·
1 Parent(s): e94132d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -67
app.py CHANGED
@@ -1,67 +1,67 @@
1
- import torch
2
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
3
- import gradio as gr
4
- import re
5
- import nltk
6
- from nltk.tokenize import word_tokenize
7
- from nltk.corpus import stopwords
8
- from nltk.stem import WordNetLemmatizer
9
-
10
- # Download NLTK resources (optional if already available)
11
- nltk.download('punkt_tab')
12
- nltk.download('stopwords')
13
- nltk.download('wordnet')
14
-
15
- # Preprocessing setup
16
- stop_words = set(stopwords.words('english'))
17
- lemmatizer = WordNetLemmatizer()
18
-
19
- def preprocess_text(text):
20
- # Remove non-alphabetic characters
21
- text = re.sub(r'[^A-Za-z\s]', '', text)
22
- # Remove URLs
23
- text = re.sub(r'http\S+|www\S+|https\S+', '', text)
24
- # Remove extra spaces
25
- text = re.sub(r'\s+', ' ', text).strip()
26
- # Lowercase
27
- text = text.lower()
28
- # Tokenize
29
- tokens = word_tokenize(text)
30
- # Remove stopwords
31
- tokens = [word for word in tokens if word not in stop_words]
32
- # Lemmatize
33
- tokens = [lemmatizer.lemmatize(word) for word in tokens]
34
- return ' '.join(tokens)
35
-
36
- # Load tokenizer and model
37
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
38
- model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
39
-
40
- # Load trained phishing detection model
41
- model.load_state_dict(torch.load("phishing_model.pth", map_location=torch.device("cpu")))
42
- model.eval()
43
-
44
- # Label mapping
45
- idx2label = {0: "phishing", 1: "legitimate"}
46
-
47
- # Prediction function
48
- def predict(text):
49
- clean_text = preprocess_text(text)
50
- inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
51
- with torch.no_grad():
52
- outputs = model(**inputs)
53
- probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0].numpy()
54
-
55
- return {idx2label[i]: float(round(probs[i], 4)) for i in range(2)}
56
-
57
- # Gradio UI
58
- interface = gr.Interface(
59
- fn=predict,
60
- inputs=gr.Textbox(lines=4, placeholder="Enter a suspicious message or account description..."),
61
- outputs=gr.Label(num_top_classes=2),
62
- title="🛡️ Phishing Account Detector",
63
- description="Detects whether an account or message is likely phishing or legitimate using a custom DistilBERT model."
64
- )
65
-
66
- if __name__ == "__main__":
67
- interface.launch()
 
1
+ import torch
2
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
3
+ import gradio as gr
4
+ import re
5
+ import nltk
6
+ from nltk.tokenize import word_tokenize
7
+ from nltk.corpus import stopwords
8
+ from nltk.stem import WordNetLemmatizer
9
+
10
+ # Download NLTK resources (optional if already available)
11
+ nltk.download('punkt_tab')
12
+ nltk.download('stopwords')
13
+ nltk.download('wordnet')
14
+
15
+ # Preprocessing setup
16
+ stop_words = set(stopwords.words('english'))
17
+ lemmatizer = WordNetLemmatizer()
18
+
19
+ def preprocess_text(text):
20
+ # Remove non-alphabetic characters
21
+ text = re.sub(r'[^A-Za-z\s]', '', text)
22
+ # Remove URLs
23
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text)
24
+ # Remove extra spaces
25
+ text = re.sub(r'\s+', ' ', text).strip()
26
+ # Lowercase
27
+ text = text.lower()
28
+ # Tokenize
29
+ tokens = word_tokenize(text)
30
+ # Remove stopwords
31
+ tokens = [word for word in tokens if word not in stop_words]
32
+ # Lemmatize
33
+ tokens = [lemmatizer.lemmatize(word) for word in tokens]
34
+ return ' '.join(tokens)
35
+
36
+ # Load tokenizer and model
37
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
38
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
39
+
40
+ # Load trained phishing detection model
41
+ model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
42
+ model.eval()
43
+
44
+ # Label mapping
45
+ idx2label = {0: "phishing", 1: "legitimate"}
46
+
47
+ # Prediction function
48
+ def predict(text):
49
+ clean_text = preprocess_text(text)
50
+ inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
51
+ with torch.no_grad():
52
+ outputs = model(**inputs)
53
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0].numpy()
54
+
55
+ return {idx2label[i]: float(round(probs[i], 4)) for i in range(2)}
56
+
57
+ # Gradio UI
58
+ interface = gr.Interface(
59
+ fn=predict,
60
+ inputs=gr.Textbox(lines=4, placeholder="Enter a suspicious message or account description..."),
61
+ outputs=gr.Label(num_top_classes=2),
62
+ title="🛡️ Phishing Account Detector",
63
+ description="Detects whether an account or message is likely phishing or legitimate using a custom DistilBERT model."
64
+ )
65
+
66
+ if __name__ == "__main__":
67
+ interface.launch()