Spaces:
Sleeping
Sleeping
| from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification | |
| import torch | |
| from config import URL_FEATURES ,device | |
| import numpy as np | |
| from pipeline import EmailFeatureExtractor | |
| import joblib | |
| import xgboost as xgb | |
class PhishingPredictor:
    """Classify an email as phishing or safe with a BERT + XGBoost stack.

    The email text is encoded by a fine-tuned DistilBERT model into a [CLS]
    embedding. That vector is concatenated with the numeric columns listed in
    ``URL_FEATURES`` and the combined row is scored by an XGBoost classifier.
    """

    # Phishing probabilities strictly above this value are labelled PHISHING.
    PHISHING_THRESHOLD = 0.5

    def __init__(self, bert_path: str, xgb_path: str):
        """Load both models and create the feature extractor.

        Args:
            bert_path: Location passed to ``from_pretrained`` for both the
                tokenizer and the DistilBERT classifier.
            xgb_path: Model file passed to ``XGBClassifier.load_model``.
        """
        print("[INFO] Initializing Models...")
        self.device = device

        # 1. BERT components; eval() switches the network to inference mode.
        self.tokenizer = DistilBertTokenizerFast.from_pretrained(bert_path)
        self.bert_model = DistilBertForSequenceClassification.from_pretrained(bert_path)
        self.bert_model.to(self.device)
        self.bert_model.eval()

        # 2. XGBoost classifier. load_model is meant for .json/.model files;
        #    a pickled .pkl model would have to be loaded with joblib instead.
        self.xgb_model = xgb.XGBClassifier()
        self.xgb_model.load_model(xgb_path)

        # 3. Project feature extractor used by predict().
        self.extractor = EmailFeatureExtractor()

    def get_cls_embedding(self, text: str) -> np.ndarray:
        """Return the [CLS] embedding of ``text`` from the fine-tuned BERT.

        Args:
            text: Text to encode; it is truncated to 256 tokens.

        Returns:
            A 2-D array with one row holding the hidden state of the first
            token (768 values per the original notes for this model).
        """
        # Gradients are never needed at inference time.
        with torch.no_grad():
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=256,
            ).to(self.device)
            # Call the bare encoder so we get hidden states rather than the
            # classification head's logits.
            outputs = self.bert_model.distilbert(**inputs)
            # Position 0 of the sequence axis is the [CLS] token.
            cls_state = outputs.last_hidden_state[:, 0, :]
        return cls_state.cpu().numpy()

    def predict(self, subject: str, body: str) -> dict:
        """Score one email and summarise the verdict.

        Args:
            subject: Email subject line.
            body: Email body text.

        Returns:
            A dict with:
              * ``"prediction"``: ``"PHISHING"`` or ``"SAFE"``.
              * ``"confidence"``: probability of the *returned* label,
                formatted as a percentage with two decimals.
              * ``"url_count"``: the ``URL_COUNT`` feature as an ``int``.
        """
        # Step 1: run the project pipeline; the first row is the email.
        processed_df = self.extractor.transform(subject, body)

        # Step 2: BERT embedding of the combined text column.
        bert_emb = self.get_cls_embedding(processed_df["text_combined"].iloc[0])

        # Step 3: numeric URL features, cast to float32 for concatenation.
        url_feats = processed_df[URL_FEATURES].to_numpy(dtype=np.float32)

        # Step 4: one row laid out as [BERT embedding | URL features].
        final_input = np.concatenate([bert_emb, url_feats], axis=1)

        # Column 1 of predict_proba is the positive (phishing) class.
        prob = float(self.xgb_model.predict_proba(final_input)[0][1])
        is_phishing = prob > self.PHISHING_THRESHOLD

        # Report confidence in the label we actually return. Showing the raw
        # phishing probability next to a SAFE verdict (e.g. "SAFE, 2.00%")
        # reads as if the model were unsure, when it is in fact 98% sure.
        confidence = prob if is_phishing else 1.0 - prob

        return {
            "prediction": "PHISHING" if is_phishing else "SAFE",
            "confidence": f"{confidence*100:.2f}%",
            "url_count": int(processed_df["URL_COUNT"].iloc[0]),
        }