Gaykar committed on
Commit
6da1ce8
·
1 Parent(s): c69597c

added models

Browse files
Files changed (1) hide show
  1. preditormodels.py +63 -0
preditormodels.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
2
+ import torch
3
+ from config import URL_FEATURES ,device
4
+ import numpy as np
5
+ from pipeline import EmailFeatureExtractor
6
+ import joblib
7
+ import xgboost as xgb
8
class PhishingPredictor:
    """Score (subject, body) email pairs as phishing or safe.

    Combines a fine-tuned DistilBERT encoder (used only to produce CLS
    embeddings) with an XGBoost classifier. The XGBoost model is assumed to
    have been trained on the 768-dim CLS embedding concatenated with the
    URL_FEATURES columns produced by EmailFeatureExtractor — TODO confirm
    feature order against the training pipeline.
    """

    def __init__(self, bert_path: str, xgb_path: str):
        """Load both models onto the configured device.

        Args:
            bert_path: Directory containing the fine-tuned DistilBERT
                tokenizer and sequence-classification weights.
            xgb_path: Path to the saved XGBoost model in .json/.model
                format (loaded via ``XGBClassifier.load_model``; a .pkl
                would need joblib instead).
        """
        print("[INFO] Initializing Models...")
        self.device = device

        # 1. Load BERT components and put them in inference mode.
        self.tokenizer = DistilBertTokenizerFast.from_pretrained(bert_path)
        self.bert_model = DistilBertForSequenceClassification.from_pretrained(bert_path)
        self.bert_model.to(self.device)
        self.bert_model.eval()

        # 2. Load the XGBoost classifier.
        # load_model handles .json/.model files; use joblib for .pkl.
        self.xgb_model = xgb.XGBClassifier()
        self.xgb_model.load_model(xgb_path)

        # 3. Feature extractor that builds text_combined + URL features.
        self.extractor = EmailFeatureExtractor()

    def get_cls_embedding(self, text: str) -> np.ndarray:
        """Return the (1, 768) CLS-token embedding for ``text``.

        Runs the underlying DistilBERT encoder under ``torch.no_grad()``
        (truncating input to 256 tokens) and returns the embedding as a
        CPU NumPy array.
        """
        with torch.no_grad():
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, padding=True, max_length=256
            ).to(self.device)

            outputs = self.bert_model.distilbert(**inputs)
            # The first token position holds the [CLS] summary embedding.
            return outputs.last_hidden_state[:, 0, :].cpu().numpy()

    def predict(self, subject: str, body: str, threshold: float = 0.5) -> dict:
        """Classify one email as PHISHING or SAFE.

        Args:
            subject: Email subject line.
            body: Email body text.
            threshold: Phishing-probability cut-off. Defaults to 0.5,
                preserving the original behavior; callers may tune it
                without retraining.

        Returns:
            dict with keys ``"prediction"`` ("PHISHING"/"SAFE"),
            ``"confidence"`` (percentage string), and ``"url_count"``.
        """
        # Step 1: Extract all engineered features using the pipeline.
        processed_df = self.extractor.transform(subject, body)

        # Step 2: BERT embedding of the combined subject+body text.
        bert_emb = self.get_cls_embedding(processed_df['text_combined'].iloc[0])

        # Step 3: Numerical features (the URL_FEATURES columns, cast to float32).
        url_feats = processed_df[URL_FEATURES].to_numpy(dtype=np.float32)

        # Step 4: Concatenate [BERT (768) + URL features] along the feature axis.
        final_input = np.concatenate([bert_emb, url_feats], axis=1)

        # Probability of the positive (phishing) class. Cast to a plain
        # float so the returned dict is JSON-serializable (predict_proba
        # yields a NumPy scalar).
        prob = float(self.xgb_model.predict_proba(final_input)[0][1])

        prediction = "PHISHING" if prob > threshold else "SAFE"

        return {
            "prediction": prediction,
            "confidence": f"{prob*100:.2f}%",
            "url_count": int(processed_df['URL_COUNT'].iloc[0]),
        }