# SpamShield-Demo / model.py
# M-Arjun's picture
# Update model.py
# b1dcc43 verified
import json
import os
import re
from urllib.parse import urlparse
import numpy as np
import onnxruntime as ort
from utils import preprocess_text
# --------- DEFAULT CONFIG ---------
# Directory that load_models() joins with the ONNX model file names.
MODEL_DIR = "models"
# Fallback spam-probability cutoff, used when the loaded metadata has no
# "spam_threshold" entry (or no metadata file exists).
SPAM_THRESHOLD = 0.5
# Messages with at most SHORT_TEXT_WORD_COUNT words are judged against a
# cutoff of at least SHORT_TEXT_THRESHOLD (see _effective_threshold).
SHORT_TEXT_WORD_COUNT = 6
SHORT_TEXT_THRESHOLD = 0.65
# Messages with at most VERY_SHORT_TEXT_WORD_COUNT words are judged against a
# cutoff of at least VERY_SHORT_TEXT_THRESHOLD; this check takes precedence
# over the "short" tier above.
VERY_SHORT_TEXT_WORD_COUNT = 3
VERY_SHORT_TEXT_THRESHOLD = 0.75
# predict_message() scores a message chunk by chunk once its preprocessed
# form has more than this many words.
LONG_TEXT_WORD_THRESHOLD = 80
# Chunking limits used by _split_text(): words per chunk and the number of
# chunks after which splitting stops (remaining words are not scored).
CHUNK_MAX_WORDS = 40
MAX_CHUNKS = 24
# Not referenced anywhere in the visible part of this module.
BLOCKED_URL_DOMAINS = set()
# Location of the optional JSON metadata file read by _load_metadata().
# NOTE(review): this joins onto the filesystem root, not MODEL_DIR --
# confirm that a root-level metadata.json is really what is intended.
METADATA_PATH = os.path.join("/", "metadata.json")
# --------- GLOBALS ---------
# Module-level state, filled in by load_models() the first time it is called
# and reused afterwards.
# Inference session for binary_model.onnx (spam / not-spam probability).
_binary_session = None
# Inference session for category_model.onnx (label for messages flagged spam).
_category_session = None
# Dict returned by _load_metadata(); read for the "spam_threshold" key.
_metadata = None
# --------- REGEX ---------
# All patterns below are case-insensitive. The plain keyword alternations
# carry no word boundaries, so they also match inside longer words (for
# example "win" matches within "window").

# Generic spam vocabulary. Not referenced in the visible part of this module.
SPAM_HINT_PATTERN = re.compile(
r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
re.IGNORECASE,
)
# Call-to-action / reward vocabulary. Not referenced in the visible part of
# this module.
SCAM_ACTION_PATTERN = re.compile(
r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
re.IGNORECASE,
)
# Whole-word "won"/"winner"/"jackpot" followed later by a whole-word reward
# term. "." does not match newlines here, so both words must appear on the
# same line, in that order. _predict_single() uses this to force a spam
# verdict when the model score alone is below the threshold.
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
r"(\b(won|winner|jackpot)\b.*\b(prize|reward|gift|voucher|iphone|cash)\b)",
re.IGNORECASE,
)
# Any http(s):// URL or www.-prefixed token; used by _contains_url().
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)
# Keywords that look suspicious next to a link. Not referenced in the visible
# part of this module.
LINK_SPAM_CUE_PATTERN = re.compile(
r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|deposit)",
re.IGNORECASE,
)
# --------- LOADERS ---------
def _load_metadata():
    """Return the model metadata as a dict.

    Reads and parses the JSON file at METADATA_PATH when it exists. When it
    does not, a default dict carrying only the built-in spam threshold is
    returned instead.
    """
    if not os.path.exists(METADATA_PATH):
        # No metadata on disk: fall back to the built-in default.
        return {"spam_threshold": SPAM_THRESHOLD}

    with open(METADATA_PATH, "r", encoding="utf-8") as handle:
        return json.load(handle)
def load_models():
    """Create the ONNX sessions and load the metadata, once each.

    Every piece of module state is initialised only while it is still None,
    so repeated calls are cheap and reuse what was loaded earlier.
    """
    global _binary_session, _category_session, _metadata

    if _binary_session is None:
        _binary_session = ort.InferenceSession(
            os.path.join(MODEL_DIR, "binary_model.onnx")
        )

    if _category_session is None:
        _category_session = ort.InferenceSession(
            os.path.join(MODEL_DIR, "category_model.onnx")
        )

    if _metadata is None:
        _metadata = _load_metadata()
# --------- HELPERS ---------
def _contains_url(text):
    """Return True when *text* holds an http(s):// or www. style URL.

    A falsy *text* (None, empty string) is treated as the empty string.
    """
    candidate = text or ""
    return URL_ANY_PATTERN.search(candidate) is not None
def _effective_threshold(text):
    """Return the spam-probability cutoff to apply to *text*.

    The base cutoff comes from the loaded metadata ("spam_threshold"),
    falling back to SPAM_THRESHOLD. Short messages are held to a stricter
    minimum: the very-short tier is checked first, then the short tier.
    The base value is never lowered.
    """
    base = float(_metadata.get("spam_threshold", SPAM_THRESHOLD))
    word_count = len(text.split())

    if word_count <= VERY_SHORT_TEXT_WORD_COUNT:
        return max(base, VERY_SHORT_TEXT_THRESHOLD)
    if word_count <= SHORT_TEXT_WORD_COUNT:
        return max(base, SHORT_TEXT_THRESHOLD)
    return base
def _split_text(text):
    """Break *text* into consecutive word-based chunks.

    Words are whitespace-delimited. Each chunk holds up to CHUNK_MAX_WORDS
    words re-joined with single spaces. Splitting stops once MAX_CHUNKS
    chunks exist, so any words beyond that point are left out.
    """
    tokens = text.split()
    # The first chunk of a non-empty text is always produced, whatever the
    # configured cap is.
    cap = max(MAX_CHUNKS, 1)
    starts = range(0, len(tokens), CHUNK_MAX_WORDS)[:cap]
    return [" ".join(tokens[start:start + CHUNK_MAX_WORDS]) for start in starts]
# --------- CORE ---------
def _predict_single(raw_text, cleaned_text):
    """Classify one message (or one chunk of a message).

    Args:
        raw_text: The text as received; only used for the giveaway regex.
        cleaned_text: The preprocessed text fed to both ONNX models and used
            to pick the threshold.

    Returns:
        Dict with "is_spam", "confidence" (the binary model's spam
        probability, unrounded), "category" ("normal" when not spam) and
        "threshold_used".
    """
    # Both models receive the same 1x1 object array holding the cleaned text.
    features = np.array([[cleaned_text]], dtype=object)

    # --- Binary model ---
    binary_input = _binary_session.get_inputs()[0].name
    binary_outputs = _binary_session.run(None, {binary_input: features})
    # Second output, first row: read as a label -> probability mapping.
    # NOTE(review): assumes a dict-like output here -- confirm against the
    # exported model.
    class_probs = binary_outputs[1][0]
    # The spam label may be keyed as the int 1 or the string '1'.
    spam_prob = float(class_probs.get(1, class_probs.get('1', 0.0)))

    threshold = _effective_threshold(cleaned_text)

    # --- Verdict: model score first, giveaway wording as an override ---
    is_spam = spam_prob >= threshold or bool(
        GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or "")
    )

    # --- Category model (only consulted for spam) ---
    category = "normal"
    if is_spam:
        category_input = _category_session.get_inputs()[0].name
        category_outputs = _category_session.run(None, {category_input: features})
        category = str(category_outputs[0][0])

    return {
        "is_spam": is_spam,
        "confidence": spam_prob,
        "category": category,
        "threshold_used": threshold
    }
# --------- PUBLIC API ---------
def predict_message(text):
    """Classify *text* as spam or not, chunking long messages.

    Messages whose preprocessed form has at most LONG_TEXT_WORD_THRESHOLD
    words are scored in one pass. Longer messages are split into chunks of
    the raw text, each chunk is preprocessed and scored on its own, and the
    message counts as spam when any chunk does.

    For chunked messages the reported confidence, category and threshold all
    come from a single chunk: the highest-confidence chunk among those
    flagged as spam, or the highest-confidence chunk overall when none was
    flagged. This keeps the result consistent, so a spam verdict is never
    paired with the "normal" category of a different, non-spam chunk.

    Args:
        text: The raw message string.

    Returns:
        Dict with "is_spam", "confidence" (rounded to 4 decimals),
        "category" and "threshold_used".
    """
    load_models()
    cleaned = preprocess_text(text)

    chunks = []
    if len(cleaned.split()) > LONG_TEXT_WORD_THRESHOLD:
        chunks = _split_text(text)

    if not chunks:
        # Short message, or nothing to chunk: score the message as a whole.
        pred = _predict_single(text, cleaned)
        pred["confidence"] = round(pred["confidence"], 4)
        return pred

    preds = [_predict_single(chunk, preprocess_text(chunk)) for chunk in chunks]

    # Prefer the chunks that actually triggered the spam verdict; fall back
    # to all chunks when the message is clean.
    spam_preds = [p for p in preds if p["is_spam"]]
    best = max(spam_preds or preds, key=lambda p: p["confidence"])

    return {
        "is_spam": bool(spam_preds),
        "confidence": round(best["confidence"], 4),
        # Non-spam chunks already carry the category "normal".
        "category": best["category"],
        "threshold_used": best["threshold_used"],
    }
def run_model(text):
    """Validate *text*, classify it, and wrap the outcome in a response dict.

    Args:
        text: Message to classify. Must be a string with at least one
            non-whitespace character.

    Returns:
        {"ok": False, "error": "Invalid input"} when *text* is not a usable
        string; otherwise {"ok": True, "input": <stripped text>,
        "result": <predict_message output>}.
    """
    # Type check comes first: testing the truthiness of an arbitrary object
    # can itself raise. Blank strings are rejected as well, since they would
    # otherwise be classified as an empty message.
    if not isinstance(text, str) or not text.strip():
        return {"ok": False, "error": "Invalid input"}

    return {
        "ok": True,
        "input": text.strip(),
        "result": predict_message(text),
    }