Text Classification
Scikit-learn
ONNX
Spam
Spam-Categoriser
SpamShield / 0.4-lite /model.py
M-Arjun's picture
0.4 and 0.4 - lite
0ec6609 verified
import json
import os
import re
from urllib.parse import urlparse
import numpy as np
import onnxruntime as ort
import config
from utils import preprocess_text
# Global variables to hold the loaded models and sessions.
# All three start as None and are filled in lazily by load_models().
_binary_session = None
_category_session = None
_metadata = None
# Signals that are strong spam indicators even in short messages.
# NOTE: the alternatives are plain substrings (no \b anchors), so e.g. "win"
# also matches inside longer words. Used by _effective_threshold() to decide
# whether the stricter short-text thresholds apply.
SPAM_HINT_PATTERN = re.compile(
    r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
    r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
    re.IGNORECASE,
)
# A whole-word "won/win/winner" followed later in the text by a whole-word
# sports/competition term.
BENIGN_WIN_CONTEXT_PATTERN = re.compile(
    r"\b(won|win|winner)\b.*\b(match|game|tournament|league|race|finals|team|football|cricket|basketball)\b",
    re.IGNORECASE,
)
# Call-to-action / payout vocabulary (substring match). Its presence cancels
# both benign-context checks below.
SCAM_ACTION_PATTERN = re.compile(
    r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
    re.IGNORECASE,
)
# Either "won/winner/jackpot/lucky draw ... <prize word>" or
# "claim/redeem ... <prize word>"; used in _predict_single() to force a spam
# verdict when the model alone did not flag the message.
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
    r"(\b(won|winner|jackpot|lucky draw)\b.*\b(lambo|lamboo|prize|reward|gift|voucher|tesla|iphone|cash)\b)|"
    r"(\b(claim|redeem)\b.*\b(prize|reward|gift|voucher)\b)",
    re.IGNORECASE,
)
# Split points for chunking: whitespace that follows ., ! or ?, or any run of
# newlines.
SENTENCE_SPLIT_PATTERN = re.compile(r"(?<=[.!?])\s+|\n+")
# Anchored form: the whole string must be a single http(s):// or www. URL.
# (Not referenced by the functions visible in this part of the file.)
URL_TOKEN_PATTERN = re.compile(r"^(https?://\S+|www\.\S+)$", re.IGNORECASE)
# Unanchored form: finds http(s):// or www. URLs anywhere in a text. The match
# runs to the next whitespace, so trailing punctuation is included.
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)
# Vocabulary that makes a message containing a link worth sending to the
# model; without any of these cues _predict_single() treats a linked message
# as normal.
LINK_SPAM_CUE_PATTERN = re.compile(
    r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|"
    r"deposit|investment|gift card|limited time|act now|suspended|login|otp|kyc)",
    re.IGNORECASE,
)
# Age / content-rating phrasing (substring match) treated as benign context.
BENIGN_ADULT_CONTEXT_PATTERN = re.compile(
    r"(older than 18|over 18|under 18|age requirement|adult supervision|age limit|"
    r"content rating|parental guidance|legal age|years old)",
    re.IGNORECASE,
)
# Software/work-chat phrasing (substring match) treated as benign context.
# NOTE: short alternatives such as "qa" also match inside longer words.
BENIGN_WORK_CONTEXT_PATTERN = re.compile(
    r"(pull request|code review|deployment|sprint|bug fix|qa|release note|"
    r"project update|meeting notes|standup|ticket|merge request|ci pipeline)",
    re.IGNORECASE,
)
# Whole message shaped like "<name>: <text>": a 2-32 character name made of
# letters, digits and & ' . _ + -, a colon, then 2-90 characters containing no
# further colon, with optional closing . ! or ?. Applied with .match() on the
# stripped raw text.
SHORT_BRAND_ALERT_PATTERN = re.compile(
    r"^[A-Za-z0-9&'._+-]{2,32}\s*:\s*[^:]{2,90}[.!?]?$"
)
# Income-promise phrasing: "$N / day|week|month", "earn $N", "get rich quick",
# "no experience needed".
MONEY_JOB_SCAM_PATTERN = re.compile(
    r"(\$\s?\d[\d,]*(?:\.\d+)?\s*/?\s*(day|week|month))|"
    r"(earn\s+\$?\s?\d[\d,]*)|"
    r"(get rich quick)|"
    r"(no experience needed)",
    re.IGNORECASE,
)
def _load_metadata():
    """Return the threshold settings used at prediction time.

    When ``config.METADATA_PATH`` exists, its JSON content is returned as-is.
    Otherwise the same five settings are taken from the constants in
    ``config``.
    """
    if not os.path.exists(config.METADATA_PATH):
        # No metadata file on disk: fall back to the configured defaults.
        return dict(
            spam_threshold=config.SPAM_THRESHOLD,
            short_text_word_count=config.SHORT_TEXT_WORD_COUNT,
            short_text_threshold=config.SHORT_TEXT_THRESHOLD,
            very_short_text_word_count=config.VERY_SHORT_TEXT_WORD_COUNT,
            very_short_text_threshold=config.VERY_SHORT_TEXT_THRESHOLD,
        )
    with open(config.METADATA_PATH, "r", encoding="utf-8") as handle:
        return json.load(handle)
def load_models():
    """Populate the module-level sessions and metadata, each at most once.

    Anything that is already loaded (not ``None``) is left untouched, so the
    function is cheap to call before every prediction.
    """
    global _binary_session, _category_session, _metadata

    def _open_session(filename):
        # Both models live in config.MODEL_DIR and run on the CPU provider.
        return ort.InferenceSession(
            os.path.join(config.MODEL_DIR, filename),
            providers=['CPUExecutionProvider'],
        )

    if _binary_session is None:
        _binary_session = _open_session("binary_model.onnx")
    if _category_session is None:
        _category_session = _open_session("category_model.onnx")
    if _metadata is None:
        _metadata = _load_metadata()
def _effective_threshold(raw_text, cleaned_text):
    """Pick the spam-probability cut-off for one message.

    The base threshold comes from ``_metadata`` (falling back to ``config``).
    For messages without any ``SPAM_HINT_PATTERN`` match, short texts get a
    cut-off that is at least the short / very-short threshold; the threshold
    is never lowered below the base value.

    Args:
        raw_text: Original message; searched for spam hints.
        cleaned_text: Preprocessed message; its space-separated tokens are
            counted.

    Returns:
        The threshold as a float.
    """
    base = float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD))
    short_limit = int(
        _metadata.get("short_text_word_count", config.SHORT_TEXT_WORD_COUNT)
    )
    short_floor = float(
        _metadata.get("short_text_threshold", config.SHORT_TEXT_THRESHOLD)
    )
    very_short_limit = int(
        _metadata.get("very_short_text_word_count", config.VERY_SHORT_TEXT_WORD_COUNT)
    )
    very_short_floor = float(
        _metadata.get("very_short_text_threshold", config.VERY_SHORT_TEXT_THRESHOLD)
    )

    # Count non-empty tokens produced by splitting on single spaces.
    token_count = sum(1 for token in cleaned_text.split(" ") if token)

    if SPAM_HINT_PATTERN.search(raw_text or "") is not None:
        # A spam hint disables the short-text leniency.
        return base
    if token_count <= very_short_limit:
        return max(base, very_short_floor)
    if token_count <= short_limit:
        return max(base, short_floor)
    return base
def _is_benign_win_context(raw_text):
    """Return True when "win" wording appears in a sports-style context.

    Requires a ``BENIGN_WIN_CONTEXT_PATTERN`` match and no
    ``SCAM_ACTION_PATTERN`` match; empty or missing text is never benign.
    """
    if not raw_text:
        return False
    if BENIGN_WIN_CONTEXT_PATTERN.search(raw_text) is None:
        return False
    return SCAM_ACTION_PATTERN.search(raw_text) is None
def _is_benign_context(raw_text):
    """Return True for age-related or work-related wording without scam cues.

    Any ``SCAM_ACTION_PATTERN`` match, or empty/missing text, yields False.
    """
    if not raw_text or SCAM_ACTION_PATTERN.search(raw_text) is not None:
        return False
    benign_patterns = (BENIGN_ADULT_CONTEXT_PATTERN, BENIGN_WORK_CONTEXT_PATTERN)
    return any(pattern.search(raw_text) is not None for pattern in benign_patterns)
def _extract_url_domains(raw_text: str) -> list[str]:
    """Return the lower-cased host name of every URL found in ``raw_text``.

    URLs are located with ``URL_ANY_PATTERN``; bare ``www.`` links get an
    ``https://`` scheme so that ``urlparse`` recognises the host part.

    The host is taken from ``hostname`` rather than ``netloc`` so that a port
    or ``user:pass@`` prefix does not become part of the domain, and trailing
    sentence punctuation captured by the URL pattern (e.g. ``www.site.com.``)
    is removed. Both would otherwise let a blocked domain slip past the
    comparison in ``_has_blocked_domain``. A leading ``m.`` (mobile
    sub-domain) is dropped.

    Args:
        raw_text: Message text; ``None`` or empty yields an empty list.

    Returns:
        Host names in order of appearance (duplicates kept). URLs that cannot
        be parsed or have no host are skipped.
    """
    if not raw_text:
        return []
    domains = []
    for match in URL_ANY_PATTERN.finditer(raw_text):
        url = match.group(0).strip()
        if url.lower().startswith("www."):
            url = "https://" + url
        try:
            host = urlparse(url).hostname or ""
        except ValueError:
            # urlparse rejects malformed netlocs (e.g. unbalanced IPv6 brackets).
            continue
        # Drop punctuation that belonged to the sentence, not to the URL.
        host = host.strip().rstrip(".,;:!?)]}>\"'").lower()
        if host.startswith("m."):
            host = host[2:]
        if host:
            domains.append(host)
    return domains
def _has_blocked_domain(raw_text: str) -> bool:
    """Tell whether any URL in ``raw_text`` points at a blocked domain.

    The blocklist is ``config.BLOCKED_URL_DOMAINS`` (treated as empty when the
    attribute is absent). A host is blocked when it equals a listed domain or
    is a sub-domain of one.
    """
    blocked = set(getattr(config, "BLOCKED_URL_DOMAINS", set()))
    if not blocked:
        return False
    hosts = _extract_url_domains(raw_text)
    if not hosts:
        return False
    return any(
        host in blocked or any(host.endswith('.' + base) for base in blocked)
        for host in hosts
    )
def _contains_url(raw_text: str) -> bool:
    """Return True when ``raw_text`` contains an http(s):// or www. URL."""
    text = raw_text or ""
    return URL_ANY_PATTERN.search(text) is not None
def _has_link_spam_cues(raw_text: str) -> bool:
    """Return True when ``raw_text`` matches ``LINK_SPAM_CUE_PATTERN``."""
    text = raw_text or ""
    return LINK_SPAM_CUE_PATTERN.search(text) is not None
def _split_long_text(text: str) -> list[str]:
    """Split a long message into word-bounded chunks, in reading order.

    The text is first cut at ``SENTENCE_SPLIT_PATTERN``. Consecutive sentences
    are packed together while they fit into ``config.CHUNK_MAX_WORDS`` words
    (default 40); a single sentence longer than that is cut into fixed-size
    word windows. At most ``config.MAX_CHUNKS`` chunks (default 24) are
    returned; later text is dropped.

    Sentences collected before an oversized sentence are emitted *before* its
    windows. Previously they stayed pending and were emitted afterwards,
    possibly glued to text that followed the long sentence.

    Args:
        text: Raw message text; ``None`` or empty yields an empty list.

    Returns:
        The chunks as stripped, non-empty strings.
    """
    # Clamp to 1: a zero window size would be an invalid range() step.
    max_words = max(1, int(getattr(config, "CHUNK_MAX_WORDS", 40)))
    max_chunks = int(getattr(config, "MAX_CHUNKS", 24))
    if max_chunks <= 0:
        return []

    sentences = [p.strip() for p in SENTENCE_SPLIT_PATTERN.split(text or "") if p.strip()]
    chunks = []
    pending = []        # sentences waiting to be joined into the next chunk
    pending_words = 0   # total word count of ``pending``

    for sentence in sentences:
        words = sentence.split()
        if not words:
            continue
        if len(words) > max_words:
            # Flush what came before so chunk order follows the text.
            if pending:
                chunks.append(" ".join(pending))
                pending, pending_words = [], 0
            chunks.extend(
                " ".join(words[start : start + max_words])
                for start in range(0, len(words), max_words)
            )
        elif pending and pending_words + len(words) > max_words:
            chunks.append(" ".join(pending))
            pending, pending_words = [sentence], len(words)
        else:
            pending.append(sentence)
            pending_words += len(words)
        if len(chunks) >= max_chunks:
            return chunks[:max_chunks]

    if pending:
        chunks.append(" ".join(pending))
    return chunks[:max_chunks]
def _predict_single(raw_text: str, cleaned_text: str) -> dict:
    """Classify one message (or one chunk of a long message).

    Steps, in order:

    1. URL shortcut: a message with a URL on a blocked domain is spam
       (confidence 0.99); a message with a URL but no link-spam cue is normal
       (confidence 0.05). Neither case runs the model.
    2. The binary ONNX model yields the spam probability, compared against
       ``_effective_threshold``.
    3. Rule adjustments: benign win/work/age context can veto a not-too-
       confident spam verdict; the giveaway pattern can force spam; the exact
       cleaned text "join now" is never spam by model score alone; money/job
       and short "brand: text" patterns lower the bar for spam.
    4. The category comes from the matching rule, else from the category
       model; non-spam is always "normal".

    Args:
        raw_text: Original text, used for all pattern checks.
        cleaned_text: Preprocessed text, fed to the ONNX models.

    Returns:
        Dict with ``is_spam``, ``confidence``, ``category`` and
        ``threshold_used``.
    """

    def _url_verdict(spam, score):
        # Result for the URL shortcuts; always reports the base threshold.
        return {
            "is_spam": spam,
            "confidence": score,
            "category": "spam" if spam else "normal",
            "threshold_used": float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD)),
        }

    if _contains_url(raw_text):
        if _has_blocked_domain(raw_text):
            return _url_verdict(True, 0.99)
        if not _has_link_spam_cues(raw_text):
            return _url_verdict(False, 0.05)

    # The models take a string tensor of shape [None, 1]; the input name is
    # read from the session rather than hard-coded.
    onnx_input = np.array([[cleaned_text]], dtype=object)
    binary_input_name = _binary_session.get_inputs()[0].name
    binary_outputs = _binary_session.run(None, {binary_input_name: onnx_input})
    # Second output: one {class: probability} mapping per row; the spam class
    # key may be the int 1 or the string '1'.
    class_probs = binary_outputs[1][0]
    spam_prob = float(class_probs.get(1, class_probs.get('1', 0.0)))

    threshold = _effective_threshold(raw_text, cleaned_text)
    text = raw_text or ""
    benign_win = _is_benign_win_context(raw_text)

    # Model verdict, unless a benign context vetoes a moderately confident hit.
    vetoed = (spam_prob < 0.92 and benign_win) or (
        spam_prob < 0.85 and _is_benign_context(raw_text)
    )
    is_spam = spam_prob >= threshold and not vetoed

    has_giveaway_override = GIVEAWAY_OVERRIDE_PATTERN.search(text) is not None
    if has_giveaway_override and not benign_win:
        is_spam = True

    short_brand_alert = SHORT_BRAND_ALERT_PATTERN.match(text.strip()) is not None
    money_job_scam = MONEY_JOB_SCAM_PATTERN.search(text) is not None

    if cleaned_text.strip().lower() == "join now":
        is_spam = False

    if not is_spam:
        # Rule matches accept a probability somewhat below the threshold.
        money_hit = money_job_scam and spam_prob >= max(0.55, threshold - 0.20)
        brand_hit = short_brand_alert and spam_prob >= max(0.50, threshold - 0.16)
        is_spam = money_hit or brand_hit

    if not is_spam:
        category = "normal"
    elif money_job_scam:
        category = "job_scam"
    elif short_brand_alert:
        category = "phishing"
    elif has_giveaway_override:
        category = "giveaway"
    else:
        category_input_name = _category_session.get_inputs()[0].name
        category_outputs = _category_session.run(None, {category_input_name: onnx_input})
        category = str(category_outputs[0][0])

    return {
        "is_spam": bool(is_spam),
        "confidence": float(spam_prob),
        "category": str(category),
        "threshold_used": float(threshold),
    }
def validate_message(text: str) -> tuple[bool, str]:
    """Check that ``text`` is usable as classifier input.

    Args:
        text: Candidate message.

    Returns:
        ``(True, "")`` when the input is acceptable, otherwise ``(False,
        reason)``. Only ``None``, non-string values and blank strings are
        rejected; symbol-only and single-character messages are accepted.
    """
    if text is None:
        return False, "Input is required."
    if not isinstance(text, str):
        return False, "Input must be a string."
    if not text.strip():
        return False, "Input cannot be empty."
    # The former "no alphanumerics" and "shorter than 2 characters" branches
    # returned this same value, so they were dead code and have been removed.
    return True, ""
def predict_message(text: str) -> dict:
    """Classify a message, chunking it first when it is long.

    Messages whose cleaned text has at most
    ``config.LONG_TEXT_WORD_THRESHOLD`` words (default 80), or for which
    chunking yields nothing, are classified in one pass (``chunked`` False).
    Longer messages are split with ``_split_long_text`` and every chunk is
    classified on its own; the message is spam when any chunk is.

    Returns:
        Dict with ``is_spam``, ``confidence``, ``category``,
        ``threshold_used`` and ``chunked``; chunked results also carry
        ``chunk_count``. For chunked results ``confidence`` is the highest
        chunk confidence, while category and threshold come from the most
        confident spam chunk (or the most confident chunk when none is spam).
    """
    load_models()
    cleaned_text = preprocess_text(text)
    word_count = sum(1 for token in cleaned_text.split(" ") if token)
    long_threshold = int(getattr(config, "LONG_TEXT_WORD_THRESHOLD", 80))

    # Only long texts are chunked; an empty chunk list means "predict whole".
    chunks = _split_long_text(text) if word_count > long_threshold else []
    if not chunks:
        pred = _predict_single(text, cleaned_text)
        return {
            "is_spam": pred["is_spam"],
            "confidence": round(pred["confidence"], 4),
            "category": pred["category"],
            "threshold_used": round(pred["threshold_used"], 4),
            "chunked": False,
        }

    def _confidence(prediction):
        return prediction["confidence"]

    chunk_predictions = [
        _predict_single(chunk, preprocess_text(chunk)) for chunk in chunks
    ]
    highest = max(chunk_predictions, key=_confidence)
    spam_chunks = [cp for cp in chunk_predictions if cp["is_spam"]]
    is_spam = bool(spam_chunks)
    representative = max(spam_chunks, key=_confidence) if is_spam else highest

    return {
        "is_spam": is_spam,
        "confidence": round(float(highest["confidence"]), 4),
        "category": str(representative["category"] if is_spam else "normal"),
        "threshold_used": round(float(representative["threshold_used"]), 4),
        "chunked": True,
        "chunk_count": len(chunks),
    }
def run_model(text: str) -> dict:
    """Validate ``text`` and, when valid, classify it.

    Returns:
        On success ``{"ok": True, "input": <stripped text>, "result":
        <prediction>}``; on validation failure ``{"ok": False, "error":
        <reason>, "input": <text as given>}``.
    """
    ok, error = validate_message(text)
    if ok:
        # The prediction runs on the text as given; only the echo is stripped.
        prediction = predict_message(text)
        return {
            "ok": True,
            "input": text.strip(),
            "result": prediction,
        }
    return {
        "ok": False,
        "error": error,
        "input": text,
    }
def update_model(text: str, label: int, category: str):
    """Append one feedback record to ``dataset/feedback.jsonl``.

    The record is a single JSON line with the stripped ``text``, ``label`` as
    an int and ``category`` as a string. The path is relative to the current
    working directory; the ``dataset`` directory is created when missing.

    Args:
        text: Message the feedback refers to; ``None`` makes the call a no-op.
        label: Value convertible with ``int()``.
        category: Category name; converted with ``str()``.

    Raises:
        ValueError, TypeError: If ``label`` cannot be converted to an int.
    """
    if text is None:
        return
    # Build the record before touching the filesystem, so that invalid input
    # raises without leaving a directory or file behind.
    record = json.dumps(
        {
            "text": text.strip(),
            "label": int(label),
            "category": str(category),
        },
        ensure_ascii=True,
    )
    os.makedirs("dataset", exist_ok=True)
    feedback_path = os.path.join("dataset", "feedback.jsonl")
    with open(feedback_path, "a", encoding="utf-8") as f:
        f.write(record + "\n")
def get_model_specs() -> dict:
    """Collect configuration, file-presence and load-status information.

    Returns:
        Dict with the configured paths and threshold, a ``files`` sub-dict
        telling which model/metadata files exist, and ``loaded``. When loading
        succeeds ``runtime_threshold`` is added; when it fails ``load_error``
        holds the exception text instead (the exception is not re-raised).
    """
    model_dir = config.MODEL_DIR

    def _in_model_dir(filename):
        return os.path.exists(os.path.join(model_dir, filename))

    specs = {
        "model_dir": model_dir,
        "binary_model_path": config.BINARY_MODEL_PATH,
        "category_model_path": config.CATEGORY_MODEL_PATH,
        "metadata_path": config.METADATA_PATH,
        "spam_threshold": config.SPAM_THRESHOLD,
        "word_max_features": getattr(config, "WORD_MAX_FEATURES", None),
        "char_max_features": getattr(config, "CHAR_MAX_FEATURES", None),
        "files": {
            "binary_onnx_exists": _in_model_dir("binary_model.onnx"),
            "category_onnx_exists": _in_model_dir("category_model.onnx"),
            "metadata_exists": os.path.exists(config.METADATA_PATH),
        },
    }
    try:
        load_models()
        specs["loaded"] = True
        specs["runtime_threshold"] = _metadata.get("spam_threshold", config.SPAM_THRESHOLD)
    except Exception as exc:
        # Reporting function: describe the failure instead of propagating it.
        specs["loaded"] = False
        specs["load_error"] = str(exc)
    return specs
def print_model_specs() -> None:
    """Print a short human-readable summary of ``get_model_specs()``."""
    specs = get_model_specs()
    files = specs['files']

    report = [
        "Model Specs (ONNX)",
        f"- Model dir: {specs['model_dir']}",
        f"- Base threshold: {specs['spam_threshold']}",
        "- Files exist: "
        f"binary_onnx={files['binary_onnx_exists']}, "
        f"category_onnx={files['category_onnx_exists']}, "
        f"metadata={files['metadata_exists']}",
    ]
    if specs.get("loaded"):
        report.append("- Loaded: True")
        report.append(f"- Runtime threshold: {specs['runtime_threshold']}")
    else:
        report.append("- Loaded: False")
        report.append(f"- Load error: {specs.get('load_error', 'unknown error')}")

    for line in report:
        print(line)
# When run as a script, print the model summary.
if __name__ == "__main__":
    print_model_specs()