import json
import os
import re
from urllib.parse import urlparse

import numpy as np
import onnxruntime as ort

from utils import preprocess_text

# --------- DEFAULT CONFIG ---------
# Directory holding "binary_model.onnx" and "category_model.onnx".
MODEL_DIR = "models"
# Fallback spam-probability threshold, used when metadata lacks "spam_threshold".
SPAM_THRESHOLD = 0.5
# Texts with at most this many words must reach at least SHORT_TEXT_THRESHOLD.
SHORT_TEXT_WORD_COUNT = 6
SHORT_TEXT_THRESHOLD = 0.65
# Texts with at most this many words must reach at least VERY_SHORT_TEXT_THRESHOLD.
VERY_SHORT_TEXT_WORD_COUNT = 3
VERY_SHORT_TEXT_THRESHOLD = 0.75
# Cleaned texts with more words than this are scored chunk by chunk.
LONG_TEXT_WORD_THRESHOLD = 80
# Words per chunk and maximum number of chunks for long texts; words beyond
# CHUNK_MAX_WORDS * MAX_CHUNKS are not scored.
CHUNK_MAX_WORDS = 40
MAX_CHUNKS = 24
# NOTE(review): not referenced by the functions in this module; presumably
# reserved or used elsewhere — verify before removing.
BLOCKED_URL_DOMAINS = set()
# NOTE(review): this evaluates to "/metadata.json" (filesystem root), not a
# file inside MODEL_DIR. Confirm this is the intended location.
METADATA_PATH = os.path.join("/", "metadata.json")

# --------- GLOBALS ---------
# Lazily populated by load_models(); None until first use.
_binary_session = None
_category_session = None
_metadata = None

# --------- REGEX ---------
# NOTE(review): SPAM_HINT_PATTERN, SCAM_ACTION_PATTERN and LINK_SPAM_CUE_PATTERN
# are not referenced by the functions in this module — verify before removing.
SPAM_HINT_PATTERN = re.compile(
    r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
    r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
    re.IGNORECASE,
)

SCAM_ACTION_PATTERN = re.compile(
    r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
    re.IGNORECASE,
)

# A "won/winner/jackpot" word followed (on the same line, since "." does not
# match newlines here) by a prize-like word forces a spam verdict.
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
    r"(\b(won|winner|jackpot)\b.*\b(prize|reward|gift|voucher|iphone|cash)\b)",
    re.IGNORECASE,
)

URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)

LINK_SPAM_CUE_PATTERN = re.compile(
    r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|deposit)",
    re.IGNORECASE,
)


# --------- LOADERS ---------
def _load_metadata():
    """Return the parsed metadata JSON, or a default dict if the file is absent.

    The default only carries ``spam_threshold`` (SPAM_THRESHOLD).
    """
    if os.path.exists(METADATA_PATH):
        with open(METADATA_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    return {"spam_threshold": SPAM_THRESHOLD}


def load_models():
    """Create the ONNX sessions and load metadata, once per process.

    Each global is only initialized while it is still None, so repeated calls
    are cheap.
    """
    global _binary_session, _category_session, _metadata
    if _binary_session is None:
        binary_path = os.path.join(MODEL_DIR, "binary_model.onnx")
        _binary_session = ort.InferenceSession(binary_path)
    if _category_session is None:
        category_path = os.path.join(MODEL_DIR, "category_model.onnx")
        _category_session = ort.InferenceSession(category_path)
    if _metadata is None:
        _metadata = _load_metadata()


# --------- HELPERS ---------
def _contains_url(text):
    """Return True if *text* contains an http(s):// or www. URL (None-safe)."""
    return bool(URL_ANY_PATTERN.search(text or ""))


def _effective_threshold(text):
    """Return the spam threshold for *text*, raised for very short texts.

    Starts from the metadata ``spam_threshold`` (default SPAM_THRESHOLD) and
    raises it to the short/very-short floor based on whitespace word count.
    Requires load_models() to have populated ``_metadata``.
    """
    threshold = float(_metadata.get("spam_threshold", SPAM_THRESHOLD))
    word_count = len(text.split())
    if word_count <= VERY_SHORT_TEXT_WORD_COUNT:
        threshold = max(threshold, VERY_SHORT_TEXT_THRESHOLD)
    elif word_count <= SHORT_TEXT_WORD_COUNT:
        threshold = max(threshold, SHORT_TEXT_THRESHOLD)
    return threshold


def _split_text(text):
    """Split *text* into chunks of up to CHUNK_MAX_WORDS words, at most MAX_CHUNKS."""
    words = text.split()
    chunks = []
    for i in range(0, len(words), CHUNK_MAX_WORDS):
        chunks.append(" ".join(words[i:i + CHUNK_MAX_WORDS]))
        if len(chunks) >= MAX_CHUNKS:
            break
    return chunks


# --------- CORE ---------
def _predict_single(raw_text, cleaned_text):
    """Score one piece of text with the binary model and, if spam, categorize it.

    Args:
        raw_text: Original text, used only for the giveaway regex override.
        cleaned_text: Preprocessed text fed to both ONNX models.

    Returns:
        dict with ``is_spam``, ``confidence`` (unrounded spam probability),
        ``category`` ("normal" when not spam) and ``threshold_used``.
    """
    onnx_input = np.array([[cleaned_text]], dtype=object)

    # Binary model. The code treats outputs[1][0] as a label -> probability
    # mapping whose keys may be int or str (presumably a ZipMap output — confirm).
    inputs = {_binary_session.get_inputs()[0].name: onnx_input}
    outputs = _binary_session.run(None, inputs)
    probs = outputs[1][0]
    spam_prob = float(probs.get(1, probs.get('1', 0.0)))

    threshold = _effective_threshold(cleaned_text)
    is_spam = spam_prob >= threshold

    # Heuristic override: obvious "you won a prize" messages are spam even when
    # the model's probability is below threshold.
    if not is_spam and GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or ""):
        is_spam = True

    # Category model runs only for spam; its first output's first row is the label.
    if is_spam:
        cat_inputs = {_category_session.get_inputs()[0].name: onnx_input}
        cat_outputs = _category_session.run(None, cat_inputs)
        category = str(cat_outputs[0][0])
    else:
        category = "normal"

    return {
        "is_spam": is_spam,
        "confidence": spam_prob,
        "category": category,
        "threshold_used": threshold,
    }


def _aggregate_chunk_predictions(preds):
    """Merge per-chunk predictions into one verdict taken from a single chunk.

    The decisive chunk is the highest-confidence chunk among those flagged as
    spam, or the highest-confidence chunk overall when none is flagged. Taking
    every field from that one chunk keeps ``category`` and ``threshold_used``
    consistent with ``is_spam``. Previously a chunk flagged only by the
    giveaway override (or under a lower threshold) could yield
    ``is_spam=True`` with ``category="normal"``.
    """
    spam_preds = [p for p in preds if p["is_spam"]]
    decisive = max(spam_preds or preds, key=lambda p: p["confidence"])
    return {
        "is_spam": bool(spam_preds),
        "confidence": round(decisive["confidence"], 4),
        "category": decisive["category"] if decisive["is_spam"] else "normal",
        "threshold_used": decisive["threshold_used"],
    }


# --------- PUBLIC API ---------
def predict_message(text):
    """Classify *text* as spam or not, with a spam category when applicable.

    Texts whose cleaned form has at most LONG_TEXT_WORD_THRESHOLD words are
    scored whole. Longer texts are split into word chunks (see _split_text),
    each scored independently, and merged by _aggregate_chunk_predictions.

    Returns:
        dict with ``is_spam``, ``confidence`` (rounded to 4 places),
        ``category`` and ``threshold_used``.
    """
    load_models()

    cleaned = preprocess_text(text)
    if len(cleaned.split()) <= LONG_TEXT_WORD_THRESHOLD:
        pred = _predict_single(text, cleaned)
        pred["confidence"] = round(pred["confidence"], 4)
        return pred

    preds = [_predict_single(chunk, preprocess_text(chunk)) for chunk in _split_text(text)]
    return _aggregate_chunk_predictions(preds)


def run_model(text):
    """Validate *text* and wrap predict_message's result in an envelope.

    Returns ``{"ok": False, "error": "Invalid input"}`` for empty or non-str
    input; otherwise ``{"ok": True, "input": <stripped text>, "result": ...}``.
    """
    if not text or not isinstance(text, str):
        return {"ok": False, "error": "Invalid input"}

    result = predict_message(text)
    return {
        "ok": True,
        "input": text.strip(),
        "result": result,
    }