LT360
Final update
47e416f
# app/ml/bert_mini_model.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
from .common import simple_text_clean, CLASS_NAMES
import traceback
import os
from transformers_interpret import SequenceClassificationExplainer
# Load the fine-tuned BERT-mini checkpoint and its tokenizer from the Hugging Face Hub.
BERT_MODEL_ID = "lleratodev/720-bert-mini-phishing-fine-tune"

# Module-level handles. Each one keeps its None value unless the matching
# load step in the try block below completes, so the functions in this module
# can detect a failed (or partially failed) start-up by checking for None.
bert_mini_tokenizer = None
bert_mini_model = None
device = None
cls_explainer_bert_mini = None

try:
    bert_mini_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_ID)
    bert_mini_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_ID)

    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    bert_mini_model.to(device)
    bert_mini_model.eval()  # inference mode

    # Attribution explainer built on top of the loaded model/tokenizer pair.
    cls_explainer_bert_mini = SequenceClassificationExplainer(
        bert_mini_model,
        bert_mini_tokenizer,
    )

    print(f"BERT-mini model ('{BERT_MODEL_ID}'), tokenizer, and Transformers-Interpret Explainer loaded successfully.")
    print(f"BERT-mini model running on device: {device}")
except Exception as load_error:
    # Import of this module must not fail: report the problem and leave the
    # handles that were not assigned yet as None.
    print(f"FATAL ERROR (BERT): Could not load model/tokenizer '{BERT_MODEL_ID}' or initialize Transformers-Interpret Explainer: {load_error}")
    traceback.print_exc()
# Use the BERT-mini model to classify emails.
def _uniform_class_probabilities(num_instances: int) -> np.ndarray:
    """Return a (num_instances, len(CLASS_NAMES)) array of equal probabilities."""
    num_classes = len(CLASS_NAMES)
    # np.full keeps the result two-dimensional even when num_instances is 0.
    return np.full((num_instances, num_classes), 1.0 / num_classes)


def bert_mini_predict_probability_for_lime(text_instances: list) -> np.ndarray:
    """Return class probabilities for a batch of texts.

    The function name indicates it is meant to be used as the classifier
    callback of a LIME text explainer, which expects one row of class
    probabilities per input text.

    Args:
        text_instances: Texts to classify.

    Returns:
        A 2-D array with one row per text. Rows hold the softmax of the model
        logits. If the model or tokenizer is not loaded, the input is empty,
        or inference raises, every row is the uniform distribution over
        ``CLASS_NAMES`` so the result is always two-dimensional.
    """
    num_instances = len(text_instances)
    if bert_mini_tokenizer is None or bert_mini_model is None or num_instances == 0:
        # Neutral answer: model unavailable, or nothing to classify.
        return _uniform_class_probabilities(num_instances)

    # Upper bound on texts per forward pass, to keep memory use bounded.
    batch_size = 32
    try:
        texts = list(text_instances)
        probability_chunks = []
        with torch.no_grad():
            for start in range(0, num_instances, batch_size):
                chunk = texts[start:start + batch_size]
                # Pad only to the longest text in the chunk; the attention
                # mask hides the padding, so padding every text to 512 tokens
                # only adds work.
                encoded = bert_mini_tokenizer(
                    chunk,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                    max_length=512,
                )
                encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
                logits = bert_mini_model(**encoded).logits
                # Shape (len(chunk), num_labels); the batch axis is preserved.
                probability_chunks.append(torch.softmax(logits, dim=-1).cpu().numpy())
        return np.concatenate(probability_chunks, axis=0)
    except Exception as e:
        # Best effort: report the failure and hand back neutral probabilities
        # rather than raising into the caller.
        print(f"Error in bert_mini_predict_probability_for_lime: {e}")
        traceback.print_exc()
        return _uniform_class_probabilities(num_instances)
def get_prediction_and_explanation_bert_mini(subject: str, sender: str, body: str) -> dict:
    """Classify one email with BERT-mini and attribute the result to its tokens.

    The sender, subject and body are each passed through ``simple_text_clean``
    and joined, in that order, with single spaces before tokenization.

    Args:
        subject: Email subject line.
        sender: Email sender field.
        body: Email body text.

    Returns:
        A dict with the keys ``prediction`` (an entry of ``CLASS_NAMES``, or
        ``"Error"``), ``label`` (predicted class index, ``-1`` on failure),
        ``confidence`` (softmax probability of the predicted class),
        ``explanation`` (at most 15 ``(token, score)`` pairs ordered by
        descending absolute score) and ``error`` (``None`` on success,
        otherwise a message).
    """
    required_components = (bert_mini_tokenizer, bert_mini_model, cls_explainer_bert_mini)
    if any(component is None for component in required_components):
        return {
            "error": "BERT-Mini Model/Tokenizer/Explainer not loaded correctly. Check server logs.",
            "prediction": "Error",
            "label": -1,
            "confidence": 0.0,
            "explanation": [],
        }

    # Clean the three fields and merge them into the single text the model sees.
    email_text = "{} {} {}".format(
        simple_text_clean(sender),
        simple_text_clean(subject),
        simple_text_clean(body),
    )

    try:
        encoded = bert_mini_tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = bert_mini_model(**encoded).logits
            # Single input, so row 0 holds the class probabilities.
            class_probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

        predicted_index = int(class_probs.argmax())
        predicted_confidence = float(class_probs[predicted_index])
        predicted_name = CLASS_NAMES[predicted_index]

        # A failure while explaining must not discard the prediction itself,
        # hence the dedicated inner try/except.
        try:
            attributions = cls_explainer_bert_mini(email_text, index=predicted_index)
            scored_tokens = [(token, float(weight)) for token, weight in attributions]
            # Keep the 15 tokens with the largest absolute attribution.
            top_tokens = sorted(
                scored_tokens,
                key=lambda pair: abs(pair[1]),
                reverse=True,
            )[:15]
        except Exception as explain_error:
            print(f"Transformers-Interpret explanation error: {explain_error}")
            traceback.print_exc()
            top_tokens = [("Explanation error with Transformers-Interpret", 0.0)]

        return {
            "prediction": predicted_name,
            "label": predicted_index,
            "confidence": predicted_confidence,
            "explanation": top_tokens,
            "error": None,
        }
    except Exception as predict_error:
        print("--- ORIGINAL ERROR in predict_with_bert_mini ---")
        print(f"Error type: {type(predict_error)}")
        print(f"Error message: {predict_error!s}")
        print("Traceback:")
        traceback.print_exc()
        return {
            "error": f"BERT-Mini Prediction error: {predict_error!s}",
            "prediction": "Error",
            "label": -1,
            "confidence": 0.0,
            "explanation": [],
        }