Spaces:
Running
Running
File size: 12,526 Bytes
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, pipeline
import torch
import torch.nn.functional as F
import os
import pickle
import re
class ContentAnalysisAgent:
    """Analyzes message text for phishing, spam, prompt-injection and
    AI-generated-content signals, combining transformer models, a pickled
    scikit-learn classifier, and keyword heuristics."""

    def __init__(self):
        # Device detection: prefer Apple-Silicon MPS, otherwise CPU.
        self.device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
        print(f"Using device: {self.device} for inference optimization.")
        # Primary classifier: DeBERTa-v3-small with a freshly initialized 2-label head.
        self.model_name = "microsoft/deberta-v3-small"
        # use_fast=False — presumably the fast tokenizer conversion is problematic for this model; verify
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2,
            # The classification head is re-sized to 2 labels, so mismatches are expected.
            ignore_mismatched_sizes=True
        ).to(self.device)
        # Sentence embedder used for text<->URL semantic similarity.
        self.minilm_name = "sentence-transformers/all-MiniLM-L6-v2"
        self.minilm_tokenizer = AutoTokenizer.from_pretrained(self.minilm_name)
        self.minilm_model = AutoModel.from_pretrained(self.minilm_name).to(self.device)
        # Optimization: use half-precision when running on MPS.
        if self.device.type == "mps":
            self.model = self.model.half()
            self.minilm_model = self.minilm_model.half()
        self.model.eval()
        self.minilm_model.eval()
        print("Loading Hugging Face pipelines...")
        try:
            self.mask_pipeline = pipeline("fill-mask", model="microsoft/deberta-v3-small")
            self.sentiment_pipeline = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
            self.has_pipelines = True
            print("Successfully loaded HF pipelines.")
        except Exception as e:
            # Pipelines are optional; downstream code checks has_pipelines.
            print(f"Failed to load HF pipelines: {e}")
            self.has_pipelines = False
        print("Loading local text ML models...")
        # Pickled scikit-learn spam model + vectorizer are expected in ../models.
        model_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'models')
        try:
            # NOTE(review): pickle.load is only safe on these bundled files — never load untrusted pickles.
            with open(os.path.join(model_dir, 'model_new.pkl'), 'rb') as f:
                self.scikit_model = pickle.load(f)
            with open(os.path.join(model_dir, 'vectorizer_new.pkl'), 'rb') as f:
                self.scikit_vectorizer = pickle.load(f)
            self.has_text_ml = True
            print("Successfully loaded text ML models.")
        except Exception as e:
            print(f"Failed to load text ML models: {e}")
            self.has_text_ml = False
        # Heuristic word lists consumed by analyze_phishing / analyze_prompt_injection.
        self.phishing_keywords = [
            'verify', 'account', 'bank', 'login', 'password', 'credit card',
            'ssn', 'social security', 'suspended', 'limited', 'unusual activity',
            'confirm identity', 'update information', 'click here', 'urgent'
        ]
        self.urgency_phrases = [
            'immediately', 'within 24 hours', 'as soon as possible',
            'urgent', 'action required', 'deadline', 'expire soon'
        ]
        self.prompt_injection_patterns = [
            'ignore previous instructions',
            'ignore all previous',
            'disregard previous',
            'system prompt',
            'you are now',
            'act as',
            'new role:',
            'forget your instructions'
        ]
def analyze_phishing(self, text):
"""Analyze text for phishing indicators"""
text_lower = text.lower()
keyword_matches = []
for keyword in self.phishing_keywords:
if keyword in text_lower:
keyword_matches.append(keyword)
urgency_matches = []
for phrase in self.urgency_phrases:
if phrase in text_lower:
urgency_matches.append(phrase)
keyword_score = min(len(keyword_matches) / 5, 1.0)
urgency_score = min(len(urgency_matches) / 3, 1.0)
has_personal_info_request = any([
'password' in text_lower and 'send' in text_lower,
'credit card' in text_lower,
'ssn' in text_lower,
'social security' in text_lower
])
if has_personal_info_request:
personal_info_score = 0.8
else:
personal_info_score = 0.0
phishing_score = (keyword_score * 0.4 + urgency_score * 0.3 + personal_info_score * 0.3)
return phishing_score, keyword_matches, urgency_matches
def analyze_prompt_injection(self, text):
"""Check for prompt injection attempts"""
text_lower = text.lower()
for pattern in self.prompt_injection_patterns:
if pattern in text_lower:
return True, [f"Prompt injection pattern detected: '{pattern}'"]
return False, []
def analyze_ai_generated(self, text):
"""Basic detection of AI-generated content patterns"""
ai_indicators = [
'as an ai', 'i am an ai', 'as a language model',
'i cannot', 'i apologize', 'i am unable to',
'unfortunately', 'i must inform you'
]
text_lower = text.lower()
matches = [ind for ind in ai_indicators if ind in text_lower]
if len(matches) > 1:
return 0.7, matches
elif len(matches) > 0:
return 0.4, matches
else:
return 0.0, []
def analyze_with_transformer(self, text):
"""Use transformer model for classification with optimized inference"""
try:
inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
with torch.inference_mode(): # Faster than no_grad
outputs = self.model(**inputs)
probabilities = F.softmax(outputs.logits.float(), dim=-1) # Cast back to float for softmax
phishing_prob = probabilities[0][1].item()
return phishing_prob
except Exception as e:
print(f"Transformer error: {e}")
return 0.5
def get_minilm_embeddings(self, text):
"""Get embeddings using all-MiniLM-L6-v2 with mean pooling (optimized)"""
inputs = self.minilm_tokenizer(text, padding=True, truncation=True, return_tensors='pt', max_length=512).to(self.device)
with torch.inference_mode():
model_output = self.minilm_model(**inputs)
# Mean Pooling
attention_mask = inputs['attention_mask']
token_embeddings = model_output[0].float() # Cast to float16 to float32 for pooling stability
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return embeddings
def analyze_connection(self, text, urls):
"""Analyze the connection between email text and URLs"""
if not urls:
return 1.0, "No URLs to analyze"
text_emb = self.get_minilm_embeddings(text)
connection_scores = []
for url in urls:
# Extract meaningful parts of the URL for semantic comparison
url_parts = url.replace('http://', '').replace('https://', '').replace('www.', '')
url_parts = re.sub(r'[/.\-_]', ' ', url_parts)
url_emb = self.get_minilm_embeddings(url_parts)
similarity = F.cosine_similarity(text_emb, url_emb).item()
connection_scores.append(similarity)
avg_connection = sum(connection_scores) / len(connection_scores)
# A very low connection score (divergence) is an indicator of phishing
if avg_connection < 0.2:
return avg_connection, "High divergence: URL content does not match email context"
elif avg_connection < 0.4:
return avg_connection, "Moderate divergence: URL seems loosely related to email context"
else:
return avg_connection, "Stable: URL matches email context"
def analyze(self, input_data):
"""Main analysis function with hybrid and connection logic"""
text = input_data['cleaned_text']
urls = input_data['urls']
# Benign baseline check for short / common messages
benign_greetings = ['hi', 'hii', 'hiii', 'hello', 'hey', 'how are you', 'how is this', 'test']
clean_msg = text.lower().strip().replace('?', '').replace('!', '')
if clean_msg in benign_greetings and not urls:
return {
'phishing_probability': 0.01,
'urgency_matches': [],
'keyword_matches': [],
'prompt_injection': False,
'ai_generated_probability': 0.05,
'spam_probability': 0.01,
'connection_score': 1.0,
'connection_message': "Safe: Benign conversational text",
'sentiment_label': "POSITIVE",
'sentiment_score': 0.99
}
phishing_score, keyword_matches, urgency_matches = self.analyze_phishing(text)
prompt_injection, injection_patterns = self.analyze_prompt_injection(text)
ai_generated_score, ai_patterns = self.analyze_ai_generated(text)
transformer_score = self.analyze_with_transformer(text)
# Hybrid Text Analysis: Combine model.pkl score with transformer_score
spam_probability = 0.0
spam_ml_prob = 0.0
if self.has_text_ml:
try:
features = self.scikit_vectorizer.transform([text])
spam_ml_prob = self.scikit_model.predict_proba(features)[0][1]
# Fine-tune transformer score using the pickle model baseline
transformer_score = (transformer_score * 0.7) + (spam_ml_prob * 0.3)
spam_probability = spam_ml_prob
except Exception as e:
print(f"Text ML model unavailable (sklearn version mismatch), using fallback: {e}")
self.has_text_ml = False # disable to avoid repeated errors
# Connection Analysis
connection_score, connection_msg = self.analyze_connection(text, urls)
# Adjust combined phishing score based on connection divergence
# If divergence is high (low connection), we increase the phishing probability
connection_penalty = max(0, 0.5 - connection_score) if connection_score < 0.4 else 0
combined_phishing = min(max(phishing_score, transformer_score) + connection_penalty, 1.0)
if spam_probability < 0.3:
spam_indicators = ['free', 'win', 'winner', 'prize', 'click here', 'offer', 'limited time', 'lottery', 'congratulations', 'cash', 'money', 'claim', 'award']
spam_matches = [ind for ind in spam_indicators if ind in text.lower()]
heuristic_spam = min(len(spam_matches) / 6, 1.0) # 1 match = 0.16 (Safe), 2 matches = 0.33 (Low), 3 matches = 0.5 (Medium)
spam_probability = max(spam_probability, heuristic_spam)
# Optional sentiment analysis using pipeline
sentiment_score = 0.0
sentiment_label = "UNKNOWN"
if self.has_pipelines:
try:
sent_result = self.sentiment_pipeline(text[:512])[0]
sentiment_label = sent_result['label']
sentiment_score = sent_result['score'] if sentiment_label == 'NEGATIVE' else (1.0 - sent_result['score'])
except Exception as e:
print(f"Error predicting sentiment: {e}")
results = {
'phishing_probability': combined_phishing,
'prompt_injection': prompt_injection,
'prompt_injection_patterns': injection_patterns,
'ai_generated_probability': ai_generated_score,
'spam_probability': spam_probability,
'spam_ml_score': spam_ml_prob,
'keyword_matches': keyword_matches,
'urgency_matches': urgency_matches,
'ai_patterns': ai_patterns,
'transformer_score': transformer_score,
'using_transformer': True,
'sentiment_score': sentiment_score,
'sentiment_label': sentiment_label,
'connection_score': connection_score,
'connection_message': connection_msg
}
return results |