import os
import traceback

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer

from .common import simple_text_clean, CLASS_NAMES

BERT_MODEL_ID = "lleratodev/720-bert-mini-phishing-fine-tune"

# Module-level globals; they remain None if loading below fails.
bert_mini_tokenizer = None
bert_mini_model = None
device = None
cls_explainer_bert_mini = None

try:
    bert_mini_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_ID)
    bert_mini_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_ID)

    # Use a GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bert_mini_model.to(device)
    bert_mini_model.eval()

    cls_explainer_bert_mini = SequenceClassificationExplainer(bert_mini_model, bert_mini_tokenizer)

    print(f"BERT-mini model ('{BERT_MODEL_ID}'), tokenizer, and Transformers-Interpret explainer loaded successfully.")
    print(f"BERT-mini model running on device: {device}")
except Exception as e:
    print(f"FATAL ERROR (BERT): Could not load model/tokenizer '{BERT_MODEL_ID}' or initialize the Transformers-Interpret explainer: {e}")
    traceback.print_exc()


def bert_mini_predict_probability_for_lime(text_instances: list) -> np.ndarray:
    """Return class probabilities for a batch of texts, shaped as LIME expects."""
    # If the model failed to load, return uniform probabilities so LIME can still run.
    if bert_mini_tokenizer is None or bert_mini_model is None:
        return np.array([[1.0 / len(CLASS_NAMES)] * len(CLASS_NAMES)] * len(text_instances))

    all_probabilities = []
    try:
        for text_instance in text_instances:
            inputs = bert_mini_tokenizer(text_instance, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Forward pass without gradients, then convert logits to probabilities.
            with torch.no_grad():
                outputs = bert_mini_model(**inputs)
                logits = outputs.logits

            probabilities_tensor = torch.softmax(logits, dim=-1)
            probabilities_for_instance = probabilities_tensor.cpu().numpy().squeeze()
            all_probabilities.append(probabilities_for_instance)

        return np.array(all_probabilities)

    except Exception as e:
        print(f"Error in bert_mini_predict_probability_for_lime: {e}")
        traceback.print_exc()
        return np.array([[1.0 / len(CLASS_NAMES)] * len(CLASS_NAMES)] * len(text_instances))
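
# A minimal sketch of how the function above can be plugged into LIME as the
# classifier_fn (assumes the `lime` package is installed; the sample text is
# illustrative only):
#
#   from lime.lime_text import LimeTextExplainer
#
#   lime_explainer = LimeTextExplainer(class_names=CLASS_NAMES)
#   lime_explanation = lime_explainer.explain_instance(
#       "Your account has been suspended, click here to verify",
#       bert_mini_predict_probability_for_lime,
#       num_features=10,
#   )
#   print(lime_explanation.as_list())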


def get_prediction_and_explanation_bert_mini(subject: str, sender: str, body: str) -> dict:
    """Classify an email and return the prediction plus word-level attributions."""
    if bert_mini_tokenizer is None or bert_mini_model is None or cls_explainer_bert_mini is None:
        return {"error": "BERT-Mini Model/Tokenizer/Explainer not loaded correctly. Check server logs.",
                "prediction": "Error", "label": -1, "confidence": 0.0, "explanation": []}

    # Clean each field, then combine them into a single input string.
    cleaned_sender = simple_text_clean(sender)
    cleaned_subject = simple_text_clean(subject)
    cleaned_body = simple_text_clean(body)
    combined_text_for_prediction = f"{cleaned_sender} {cleaned_subject} {cleaned_body}"

    try:
        inputs = bert_mini_tokenizer(combined_text_for_prediction, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass without gradients, then convert logits to probabilities.
        with torch.no_grad():
            outputs = bert_mini_model(**inputs)
            logits = outputs.logits

        probabilities_tensor = torch.softmax(logits, dim=-1)
        probabilities = probabilities_tensor.cpu().numpy()[0]

        prediction_label_int = np.argmax(probabilities).item()
        confidence_score = probabilities[prediction_label_int].item()
        predicted_class_name = CLASS_NAMES[prediction_label_int]

        # Explain the predicted class with Transformers-Interpret and keep the
        # 15 words with the largest absolute attribution scores.
        explanation_data = []
        try:
            word_attributions = cls_explainer_bert_mini(
                combined_text_for_prediction,
                index=prediction_label_int
            )
            explanation_data = [(word, float(score)) for word, score in word_attributions]
            explanation_data.sort(key=lambda x: abs(x[1]), reverse=True)
            explanation_data = explanation_data[:15]
        except Exception as e:
            print(f"Transformers-Interpret explanation error: {e}")
            traceback.print_exc()
            explanation_data = [("Explanation error with Transformers-Interpret", 0.0)]

        return {
            "prediction": predicted_class_name,
            "label": int(prediction_label_int),
            "confidence": float(confidence_score),
            "explanation": explanation_data,
            "error": None
        }
    except Exception as e:
        print("--- ORIGINAL ERROR in get_prediction_and_explanation_bert_mini ---")
        print(f"Error type: {type(e)}")
        print(f"Error message: {str(e)}")
        print("Traceback:")
        traceback.print_exc()
        return {"error": f"BERT-Mini Prediction error: {str(e)}", "prediction": "Error",
                "label": -1, "confidence": 0.0, "explanation": []}
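
# Minimal usage sketch (illustrative only; the sample email below is a made-up
# placeholder, and this assumes the model and explainer loaded successfully):
if __name__ == "__main__":
    result = get_prediction_and_explanation_bert_mini(
        subject="Urgent: verify your account",
        sender="security@example.com",
        body="Click the link below to confirm your password within 24 hours.",
    )
    print(result["prediction"], result["confidence"])
    for word, score in result["explanation"]:
        print(f"{word}: {score:.4f}")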