# SpamShield-AI / model.py
# M-Arjun's picture
# Update model.py
# b1dcc43 verified
# Raw
# History Blame Contribute Delete
# 4.8 kB
import json
import os
import re
from urllib.parse import urlparse
import numpy as np
import onnxruntime as ort
from utils import preprocess_text
# --------- DEFAULT CONFIG ---------
# Directory (relative to the process working directory) that load_models()
# joins with "binary_model.onnx" and "category_model.onnx".
MODEL_DIR = "models"
# Fallback spam-probability cut-off, used when the metadata has no
# "spam_threshold" entry.
SPAM_THRESHOLD = 0.5
# Texts with at most this many whitespace-separated words use at least
# SHORT_TEXT_THRESHOLD as their cut-off (see _effective_threshold).
SHORT_TEXT_WORD_COUNT = 6
SHORT_TEXT_THRESHOLD = 0.65
# Texts with at most this many words use at least VERY_SHORT_TEXT_THRESHOLD;
# this tier is checked before the "short" tier.
VERY_SHORT_TEXT_WORD_COUNT = 3
VERY_SHORT_TEXT_THRESHOLD = 0.75
# predict_message() switches to chunked prediction when the cleaned text has
# more words than this.
LONG_TEXT_WORD_THRESHOLD = 80
# Chunk size in words, and the cap on how many chunks _split_text() returns;
# words beyond CHUNK_MAX_WORDS * MAX_CHUNKS are not scored.
CHUNK_MAX_WORDS = 40
MAX_CHUNKS = 24
# Not referenced anywhere in the visible part of this module.
BLOCKED_URL_DOMAINS = set()
# NOTE(review): this joins the root separator with the file name, i.e. it
# points at "metadata.json" in the filesystem root, not inside MODEL_DIR.
# Confirm that location is intended.
METADATA_PATH = os.path.join("/", "metadata.json")
# --------- GLOBALS ---------
# Lazily populated module state. Each name starts as None and is assigned by
# load_models() the first time it is found unset.
_binary_session = None  # ort.InferenceSession for binary_model.onnx
_category_session = None  # ort.InferenceSession for category_model.onnx
_metadata = None  # value returned by _load_metadata()
# --------- REGEX ---------
# All patterns below are compiled case-insensitively.

# Keyword list matched as plain substrings (no word boundaries), so e.g.
# "win" also matches inside a longer word.
# Not referenced in the visible part of this module.
SPAM_HINT_PATTERN = re.compile(
r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
re.IGNORECASE,
)
# Call-to-action / lure keywords, also matched as substrings.
# Not referenced in the visible part of this module.
SCAM_ACTION_PATTERN = re.compile(
r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
re.IGNORECASE,
)
# Whole-word "won"/"winner"/"jackpot" followed later by a whole-word reward
# term. "." does not match newlines here (no DOTALL), so both words must be on
# the same line. _predict_single() uses this to force a spam verdict.
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
r"(\b(won|winner|jackpot)\b.*\b(prize|reward|gift|voucher|iphone|cash)\b)",
re.IGNORECASE,
)
# An "http://" / "https://" URL or a "www." host, each running up to the next
# whitespace. Used by _contains_url().
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)
# Keyword list matched as substrings.
# Not referenced in the visible part of this module.
LINK_SPAM_CUE_PATTERN = re.compile(
r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|deposit)",
re.IGNORECASE,
)
# --------- LOADERS ---------
def _load_metadata():
    """Return the model metadata, or built-in defaults when no file exists.

    Returns:
        The parsed JSON content of ``METADATA_PATH`` when that path exists;
        otherwise a dict holding only the default ``"spam_threshold"``.
    """
    fallback = {
        "spam_threshold": SPAM_THRESHOLD
    }
    if not os.path.exists(METADATA_PATH):
        return fallback

    with open(METADATA_PATH, "r", encoding="utf-8") as handle:
        loaded = json.load(handle)
    return loaded
def load_models():
    """Create the ONNX sessions and load metadata, once per process.

    Each module-level global is initialised only while it is still ``None``,
    so repeated calls do not reload anything.
    """
    global _binary_session, _category_session, _metadata

    def _open_session(filename):
        # Model files are resolved relative to MODEL_DIR.
        return ort.InferenceSession(os.path.join(MODEL_DIR, filename))

    if _binary_session is None:
        _binary_session = _open_session("binary_model.onnx")

    if _category_session is None:
        _category_session = _open_session("category_model.onnx")

    if _metadata is None:
        _metadata = _load_metadata()
# --------- HELPERS ---------
def _contains_url(text):
    """Return True when *text* contains something matching URL_ANY_PATTERN.

    Falsy input (``None``, ``""``) is treated as an empty string.
    """
    haystack = text if text else ""
    return URL_ANY_PATTERN.search(haystack) is not None
def _effective_threshold(text):
    """Return the spam-probability cut-off to apply to *text*.

    The base value is the metadata's ``"spam_threshold"`` (falling back to
    ``SPAM_THRESHOLD``). Short texts get a floor applied on top of it: the
    very-short tier is checked first, then the short tier.
    """
    base = float(_metadata.get("spam_threshold", SPAM_THRESHOLD))
    word_total = len(text.split())

    if word_total <= VERY_SHORT_TEXT_WORD_COUNT:
        return max(base, VERY_SHORT_TEXT_THRESHOLD)
    if word_total <= SHORT_TEXT_WORD_COUNT:
        return max(base, SHORT_TEXT_THRESHOLD)
    return base
def _split_text(text):
    """Split *text* into consecutive chunks of at most CHUNK_MAX_WORDS words.

    Words are whitespace-separated and re-joined with single spaces. At most
    ``MAX_CHUNKS`` chunks are returned; any remaining words are dropped.
    Text with no words yields an empty list.
    """
    tokens = text.split()
    # The cap is enforced only after a chunk has been produced, so non-empty
    # input always yields at least one chunk.
    chunk_cap = max(MAX_CHUNKS, 1)
    offsets = range(0, len(tokens), CHUNK_MAX_WORDS)[:chunk_cap]
    return [" ".join(tokens[start:start + CHUNK_MAX_WORDS]) for start in offsets]
# --------- CORE ---------
def _predict_single(raw_text, cleaned_text):
    """Score one piece of text with the binary and category models.

    Args:
        raw_text: Text before preprocessing; only used for the giveaway
            regex override.
        cleaned_text: Preprocessed text that is fed to both ONNX sessions and
            used to pick the threshold.

    Returns:
        A dict with ``is_spam``, ``confidence`` (the unrounded spam
        probability from the binary model), ``category`` and
        ``threshold_used``.
    """
    # Both sessions receive the same 1x1 object array holding the text.
    features = np.array([[cleaned_text]], dtype=object)

    # Binary model: the second output's first element is read as a mapping
    # from class label to probability; label 1 may be keyed as int or str.
    binary_input_name = _binary_session.get_inputs()[0].name
    binary_outputs = _binary_session.run(None, {binary_input_name: features})
    label_probs = binary_outputs[1][0]
    str_key_prob = label_probs.get('1', 0.0)
    spam_prob = float(label_probs.get(1, str_key_prob))

    threshold = _effective_threshold(cleaned_text)

    # Heuristic: an obvious giveaway phrase in the raw text forces a spam
    # verdict even when the model score is below the threshold.
    is_spam = spam_prob >= threshold or bool(
        GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or "")
    )

    # The category model is only consulted for messages judged to be spam.
    category = "normal"
    if is_spam:
        category_input_name = _category_session.get_inputs()[0].name
        category_outputs = _category_session.run(
            None, {category_input_name: features}
        )
        category = str(category_outputs[0][0])

    return {
        "is_spam": is_spam,
        "confidence": spam_prob,
        "category": category,
        "threshold_used": threshold
    }
# --------- PUBLIC API ---------
def predict_message(text):
    """Classify *text* as spam or not, chunking long messages.

    Messages whose cleaned form has at most ``LONG_TEXT_WORD_THRESHOLD``
    words are scored in one pass. Longer messages are split into chunks of
    the raw text, each chunk is preprocessed and scored independently, and
    the message is spam if any chunk is.

    Args:
        text: The raw message.

    Returns:
        A dict with ``is_spam``, ``confidence`` (rounded to 4 decimals),
        ``category`` and ``threshold_used``. For chunked messages the last
        three fields all come from the same chunk: the highest-confidence
        chunk that was flagged as spam, or the highest-confidence chunk
        overall when none was flagged.
    """
    load_models()
    cleaned = preprocess_text(text)

    # The length decision uses the cleaned text, but chunks are cut from the
    # raw text so each chunk can be preprocessed on its own.
    is_long = len(cleaned.split()) > LONG_TEXT_WORD_THRESHOLD
    chunks = _split_text(text) if is_long else []

    if not chunks:
        # Short message, or nothing to chunk: score the text in one pass
        # rather than calling max() on an empty list.
        pred = _predict_single(text, cleaned)
        pred["confidence"] = round(pred["confidence"], 4)
        return pred

    preds = [_predict_single(chunk, preprocess_text(chunk)) for chunk in chunks]
    flagged = [p for p in preds if p["is_spam"]]

    # Report the chunk that justifies the verdict. Taking the overall
    # highest-confidence chunk could yield is_spam=True together with
    # category "normal" when a different chunk triggered the spam verdict
    # (e.g. via the giveaway override or a lower per-chunk threshold).
    representative = max(flagged or preds, key=lambda p: p["confidence"])

    return {
        "is_spam": bool(flagged),
        "confidence": round(representative["confidence"], 4),
        "category": representative["category"],
        "threshold_used": representative["threshold_used"]
    }
def run_model(text):
    """Validate *text*, classify it, and wrap the outcome in a response dict.

    Args:
        text: The message to classify. Must be a string containing at least
            one non-whitespace character.

    Returns:
        ``{"ok": False, "error": "Invalid input"}`` for non-string, empty or
        whitespace-only input; otherwise ``{"ok": True, "input": <stripped
        text>, "result": <predict_message output>}``.
    """
    # Check the type before truth-testing: evaluating ``not text`` on an
    # arbitrary object (e.g. a multi-element array) can itself raise.
    if not isinstance(text, str):
        return {"ok": False, "error": "Invalid input"}

    stripped = text.strip()
    if not stripped:
        # Whitespace-only text carries nothing to classify and would be
        # echoed back as an empty "input".
        return {"ok": False, "error": "Invalid input"}

    result = predict_message(text)
    return {
        "ok": True,
        "input": stripped,
        "result": result
    }