Instructions to use M-Arjun/SpamShield with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Scikit-learn
How to use M-Arjun/SpamShield with Scikit-learn:
How to use M-Arjun/SpamShield with Scikit-learn:

from huggingface_hub import hf_hub_download
import joblib

# Only load pickle files from sources you trust.
# Read more about it here: https://skops.readthedocs.io/en/stable/persistence.html
model = joblib.load(
    hf_hub_download("M-Arjun/SpamShield", "sklearn_model.joblib")
)

- Notebooks
- Google Colab
- Kaggle
| import json | |
| import os | |
| import re | |
| from urllib.parse import urlparse | |
| import numpy as np | |
| import onnxruntime as ort | |
| import config | |
| from utils import preprocess_text | |
# Global variables to hold the loaded models and sessions
# (populated lazily by load_models(); None until first use).
_binary_session = None
_category_session = None
_metadata = None
# Signals that are strong spam indicators even in short messages.
SPAM_HINT_PATTERN = re.compile(
    r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
    r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
    re.IGNORECASE,
)
# "won/win/winner" followed by sports vocabulary: a match result, not a scam.
BENIGN_WIN_CONTEXT_PATTERN = re.compile(
    r"\b(won|win|winner)\b.*\b(match|game|tournament|league|race|finals|team|football|cricket|basketball)\b",
    re.IGNORECASE,
)
# Call-to-action / payout vocabulary that turns "win" wording into a scam signal.
SCAM_ACTION_PATTERN = re.compile(
    r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
    re.IGNORECASE,
)
# "You won a <prize>" / "claim your <reward>" templates -- forces a spam verdict.
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
    r"(\b(won|winner|jackpot|lucky draw)\b.*\b(lambo|lamboo|prize|reward|gift|voucher|tesla|iphone|cash)\b)|"
    r"(\b(claim|redeem)\b.*\b(prize|reward|gift|voucher)\b)",
    re.IGNORECASE,
)
# Sentence boundaries (after .!? or at newlines) used when chunking long texts.
SENTENCE_SPLIT_PATTERN = re.compile(r"(?<=[.!?])\s+|\n+")
# A token that is, in its entirety, a URL.
URL_TOKEN_PATTERN = re.compile(r"^(https?://\S+|www\.\S+)$", re.IGNORECASE)
# A URL appearing anywhere inside the text.
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)
# Vocabulary that makes a link-bearing message look like phishing/spam.
LINK_SPAM_CUE_PATTERN = re.compile(
    r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|"
    r"deposit|investment|gift card|limited time|act now|suspended|login|otp|kyc)",
    re.IGNORECASE,
)
# Age-related wording that is legitimate (age limits, content ratings).
BENIGN_ADULT_CONTEXT_PATTERN = re.compile(
    r"(older than 18|over 18|under 18|age requirement|adult supervision|age limit|"
    r"content rating|parental guidance|legal age|years old)",
    re.IGNORECASE,
)
# Everyday software-work vocabulary treated as benign context.
BENIGN_WORK_CONTEXT_PATTERN = re.compile(
    r"(pull request|code review|deployment|sprint|bug fix|qa|release note|"
    r"project update|meeting notes|standup|ticket|merge request|ci pipeline)",
    re.IGNORECASE,
)
# Short "<Brand>: <message>" alert shape often used in SMS phishing.
SHORT_BRAND_ALERT_PATTERN = re.compile(
    r"^[A-Za-z0-9&'._+-]{2,32}\s*:\s*[^:]{2,90}[.!?]?$"
)
# Money-for-nothing job-scam templates ("$500/week", "earn $...", etc.).
MONEY_JOB_SCAM_PATTERN = re.compile(
    r"(\$\s?\d[\d,]*(?:\.\d+)?\s*/?\s*(day|week|month))|"
    r"(earn\s+\$?\s?\d[\d,]*)|"
    r"(get rich quick)|"
    r"(no experience needed)",
    re.IGNORECASE,
)
def _load_metadata():
    """Return threshold metadata from disk, or config defaults when absent."""
    path = config.METADATA_PATH
    if not os.path.exists(path):
        # No metadata file yet -- fall back to the static config values.
        return {
            "spam_threshold": config.SPAM_THRESHOLD,
            "short_text_word_count": config.SHORT_TEXT_WORD_COUNT,
            "short_text_threshold": config.SHORT_TEXT_THRESHOLD,
            "very_short_text_word_count": config.VERY_SHORT_TEXT_WORD_COUNT,
            "very_short_text_threshold": config.VERY_SHORT_TEXT_THRESHOLD,
        }
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_models():
    """Lazily initialise the ONNX sessions and metadata (idempotent)."""
    global _binary_session, _category_session, _metadata
    providers = ['CPUExecutionProvider']
    if _binary_session is None:
        _binary_session = ort.InferenceSession(
            os.path.join(config.MODEL_DIR, "binary_model.onnx"),
            providers=providers,
        )
    if _category_session is None:
        _category_session = ort.InferenceSession(
            os.path.join(config.MODEL_DIR, "category_model.onnx"),
            providers=providers,
        )
    if _metadata is None:
        _metadata = _load_metadata()
def _effective_threshold(raw_text, cleaned_text):
    """Pick the spam threshold, raising it for short texts without spam hints.

    Short messages are noisy for the model, so unless the raw text contains
    a strong spam cue the threshold is floored upward based on word count.
    """
    md = _metadata
    threshold = float(md.get("spam_threshold", config.SPAM_THRESHOLD))
    short_limit = int(md.get("short_text_word_count", config.SHORT_TEXT_WORD_COUNT))
    short_floor = float(md.get("short_text_threshold", config.SHORT_TEXT_THRESHOLD))
    very_short_limit = int(
        md.get("very_short_text_word_count", config.VERY_SHORT_TEXT_WORD_COUNT)
    )
    very_short_floor = float(
        md.get("very_short_text_threshold", config.VERY_SHORT_TEXT_THRESHOLD)
    )
    word_count = sum(1 for w in cleaned_text.split(" ") if w)
    if not SPAM_HINT_PATTERN.search(raw_text or ""):
        if word_count <= very_short_limit:
            threshold = max(threshold, very_short_floor)
        elif word_count <= short_limit:
            threshold = max(threshold, short_floor)
    return threshold
def _is_benign_win_context(raw_text):
    """True when 'win' wording looks like a sports result with no scam call-to-action."""
    if not raw_text:
        return False
    if SCAM_ACTION_PATTERN.search(raw_text):
        return False
    return BENIGN_WIN_CONTEXT_PATTERN.search(raw_text) is not None
def _is_benign_context(raw_text):
    """True for benign age-limit or work-chatter text with no scam call-to-action."""
    if not raw_text:
        return False
    if SCAM_ACTION_PATTERN.search(raw_text):
        return False
    adult = BENIGN_ADULT_CONTEXT_PATTERN.search(raw_text) is not None
    work = BENIGN_WORK_CONTEXT_PATTERN.search(raw_text) is not None
    return adult or work
def _extract_url_domains(raw_text: str) -> list[str]:
    """Return the lowercased hostname of every URL found in *raw_text*."""
    if not raw_text:
        return []
    hosts = []
    for match in URL_ANY_PATTERN.finditer(raw_text):
        candidate = match.group(0).strip()
        # urlparse only fills netloc when a scheme is present, so prefix
        # bare "www." links.
        if candidate.lower().startswith("www."):
            candidate = "https://" + candidate
        try:
            host = (urlparse(candidate).netloc or "").lower().strip()
        except Exception:
            continue
        if not host:
            continue
        # Fold mobile subdomains into the parent domain.
        if host.startswith("m."):
            host = host[2:]
        hosts.append(host)
    return hosts
def _has_blocked_domain(raw_text: str) -> bool:
    """True when any URL in the text resolves to a configured blocked domain.

    A host matches when it equals a blocked entry or is a subdomain of one.
    """
    blocked = set(getattr(config, "BLOCKED_URL_DOMAINS", set()))
    if not blocked:
        return False
    for host in _extract_url_domains(raw_text):
        if host in blocked:
            return True
        if any(host.endswith('.' + base) for base in blocked):
            return True
    return False
def _contains_url(raw_text: str) -> bool:
    """True when the text contains at least one http(s)/www URL."""
    text = raw_text or ""
    return URL_ANY_PATTERN.search(text) is not None
def _has_link_spam_cues(raw_text: str) -> bool:
    """True when the text matches known link-spam/phishing vocabulary."""
    text = raw_text or ""
    return LINK_SPAM_CUE_PATTERN.search(text) is not None
def _split_long_text(text: str) -> list[str]:
    """Split *text* into word-bounded chunks for per-chunk classification.

    Sentences (and newline-separated segments, per SENTENCE_SPLIT_PATTERN)
    are greedily packed into chunks of at most CHUNK_MAX_WORDS words; a
    single over-long sentence is hard-split on word boundaries.  At most
    MAX_CHUNKS chunks are returned.
    """
    max_words = int(getattr(config, "CHUNK_MAX_WORDS", 40))
    max_chunks = int(getattr(config, "MAX_CHUNKS", 24))
    parts = [p.strip() for p in SENTENCE_SPLIT_PATTERN.split(text or "") if p.strip()]
    chunks = []
    current = []
    current_words = 0

    def _flush():
        # Emit the pending buffer as one chunk.
        nonlocal current, current_words
        if current:
            chunks.append(" ".join(current).strip())
            current = []
            current_words = 0

    for part in parts:
        words = part.split()
        if not words:
            continue
        if len(words) > max_words:
            # BUGFIX: flush the pending buffer before emitting the pieces of
            # an over-long sentence.  Previously the pieces were appended
            # ahead of buffered earlier text, scrambling chunk order and
            # merging text from before and after the long sentence.
            _flush()
            for i in range(0, len(words), max_words):
                piece = " ".join(words[i : i + max_words]).strip()
                if piece:
                    chunks.append(piece)
                if len(chunks) >= max_chunks:
                    return chunks[:max_chunks]
            continue
        if current_words + len(words) > max_words and current:
            _flush()
        current.append(part)
        current_words += len(words)
        if len(chunks) >= max_chunks:
            return chunks[:max_chunks]
    if current and len(chunks) < max_chunks:
        _flush()
    return chunks[:max_chunks]
def _predict_single(raw_text: str, cleaned_text: str) -> dict:
    """Classify one text with the binary model plus heuristic overrides.

    Returns a dict with keys: is_spam, confidence (binary spam probability),
    category, threshold_used.  Assumes load_models() has already populated
    the module-level sessions and _metadata.
    """
    # URL fast paths: a blocked domain is spam outright; a link without any
    # spammy vocabulary is treated as normal without running the model.
    if _contains_url(raw_text):
        if _has_blocked_domain(raw_text):
            return {
                "is_spam": True,
                "confidence": 0.99,
                "category": "spam",
                "threshold_used": float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD)),
            }
        if not _has_link_spam_cues(raw_text):
            return {
                "is_spam": False,
                "confidence": 0.05,
                "category": "normal",
                "threshold_used": float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD)),
            }
    # Prepare input for ONNX
    # Input name was set to 'input' in conversion script.
    # It expects StringTensorType([None, 1])
    onnx_input = np.array([[cleaned_text]], dtype=object)
    # Binary prediction
    binary_inputs = {_binary_session.get_inputs()[0].name: onnx_input}
    # Output names are usually 'label' and 'probabilities'
    binary_outputs = _binary_session.run(None, binary_inputs)
    # binary_outputs[1] is a list of dictionaries like [{'0': 0.9, '1': 0.1}]
    # Let's verify the actual output format.
    # usually it's [labels, [{0: prob, 1: prob}]]
    probs = binary_outputs[1][0]
    # The probability key may be int or str depending on the converter -- try both.
    spam_prob = float(probs.get(1, probs.get('1', 0.0)))
    threshold = _effective_threshold(raw_text, cleaned_text)
    is_spam = spam_prob >= threshold
    # Soften borderline positives in clearly benign contexts (sports wins,
    # age-limit/work wording) unless the model is very confident.
    if is_spam and spam_prob < 0.92 and _is_benign_win_context(raw_text):
        is_spam = False
    if is_spam and spam_prob < 0.85 and _is_benign_context(raw_text):
        is_spam = False
    # Hard override: giveaway-prize wording is flagged even below threshold,
    # unless it reads like a sports result.
    has_giveaway_override = bool(GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or ""))
    if not is_spam and has_giveaway_override and not _is_benign_win_context(raw_text):
        is_spam = True
    short_brand_alert = bool(SHORT_BRAND_ALERT_PATTERN.match((raw_text or "").strip()))
    money_job_scam = bool(MONEY_JOB_SCAM_PATTERN.search(raw_text or ""))
    # NOTE(review): a bare "join now" message is whitelisted unconditionally --
    # presumably a known false positive; confirm this is still intended.
    if cleaned_text.strip().lower() == "join now":
        is_spam = False
    # Lower the effective bar for texts matching scammy templates.
    if not is_spam:
        if money_job_scam and spam_prob >= max(0.55, threshold - 0.20):
            is_spam = True
        elif short_brand_alert and spam_prob >= max(0.50, threshold - 0.16):
            is_spam = True
    if is_spam:
        # Heuristic matches decide the category directly; otherwise defer to
        # the category model.
        if money_job_scam:
            category = "job_scam"
        elif short_brand_alert:
            category = "phishing"
        elif has_giveaway_override:
            category = "giveaway"
        else:
            category_inputs = {_category_session.get_inputs()[0].name: onnx_input}
            category_outputs = _category_session.run(None, category_inputs)
            category = str(category_outputs[0][0])
    else:
        category = "normal"
    return {
        "is_spam": bool(is_spam),
        "confidence": float(spam_prob),
        "category": str(category),
        "threshold_used": float(threshold),
    }
def validate_message(text: str) -> tuple[bool, str]:
    """Validate raw input; return ``(ok, error_message)``.

    Rejects ``None``, non-string values, and whitespace-only strings.  Any
    other string passes -- punctuation-only and single-character messages
    are deliberately let through for the classifier itself to judge.
    """
    if text is None:
        return False, "Input is required."
    if not isinstance(text, str):
        return False, "Input must be a string."
    if not text.strip():
        return False, "Input cannot be empty."
    return True, ""
def predict_message(text: str) -> dict:
    """Classify a message, chunking long texts and aggregating per-chunk verdicts.

    Short texts take a single model pass; long texts are split into chunks,
    and the message is spam if any chunk is spam.  Confidence reported is the
    highest chunk confidence; category comes from the strongest spam chunk.
    """
    load_models()
    cleaned_text = preprocess_text(text)

    def _single_pass():
        # One whole-text prediction shaped into the public response format.
        pred = _predict_single(text, cleaned_text)
        return {
            "is_spam": pred["is_spam"],
            "confidence": round(pred["confidence"], 4),
            "category": pred["category"],
            "threshold_used": round(pred["threshold_used"], 4),
            "chunked": False,
        }

    word_count = sum(1 for w in cleaned_text.split(" ") if w)
    long_threshold = int(getattr(config, "LONG_TEXT_WORD_THRESHOLD", 80))
    if word_count <= long_threshold:
        return _single_pass()

    chunks = _split_long_text(text)
    if not chunks:
        return _single_pass()

    chunk_predictions = [_predict_single(chunk, preprocess_text(chunk)) for chunk in chunks]
    highest = max(chunk_predictions, key=lambda cp: cp["confidence"])
    spam_chunks = [cp for cp in chunk_predictions if cp["is_spam"]]
    is_spam = bool(spam_chunks)
    representative = max(spam_chunks, key=lambda cp: cp["confidence"]) if is_spam else highest
    return {
        "is_spam": bool(is_spam),
        "confidence": round(float(highest["confidence"]), 4),
        "category": str(representative["category"] if is_spam else "normal"),
        "threshold_used": round(float(representative["threshold_used"]), 4),
        "chunked": True,
        "chunk_count": len(chunks),
    }
def run_model(text: str) -> dict:
    """Validate then classify *text*, wrapping the result in an envelope dict."""
    ok, error = validate_message(text)
    if not ok:
        # Invalid input: echo it back alongside the validation error.
        return {"ok": False, "error": error, "input": text}
    prediction = predict_message(text)
    return {"ok": True, "input": text.strip(), "result": prediction}
def update_model(text: str, label: int, category: str):
    """Append one labelled feedback example to ``dataset/feedback.jsonl``.

    Despite the name, this does not retrain anything -- it only records
    feedback as JSON Lines for a later training run.  ``None`` text is
    silently ignored.
    """
    if text is None:
        return
    feedback_dir = "dataset"
    os.makedirs(feedback_dir, exist_ok=True)
    record = json.dumps(
        {
            "text": text.strip(),
            "label": int(label),
            "category": str(category),
        },
        ensure_ascii=True,
    )
    path = os.path.join(feedback_dir, "feedback.jsonl")
    with open(path, "a", encoding="utf-8") as f:
        f.write(record + "\n")
def get_model_specs() -> dict:
    """Return a diagnostic snapshot: paths, thresholds, file presence, load status."""
    binary_onnx = os.path.join(config.MODEL_DIR, "binary_model.onnx")
    category_onnx = os.path.join(config.MODEL_DIR, "category_model.onnx")
    specs = {
        "model_dir": config.MODEL_DIR,
        "binary_model_path": config.BINARY_MODEL_PATH,
        "category_model_path": config.CATEGORY_MODEL_PATH,
        "metadata_path": config.METADATA_PATH,
        "spam_threshold": config.SPAM_THRESHOLD,
        "word_max_features": getattr(config, "WORD_MAX_FEATURES", None),
        "char_max_features": getattr(config, "CHAR_MAX_FEATURES", None),
        "files": {
            "binary_onnx_exists": os.path.exists(binary_onnx),
            "category_onnx_exists": os.path.exists(category_onnx),
            "metadata_exists": os.path.exists(config.METADATA_PATH),
        },
    }
    try:
        load_models()
    except Exception as exc:
        # Loading may fail (missing files, bad ONNX) -- report, don't raise.
        specs["loaded"] = False
        specs["load_error"] = str(exc)
    else:
        specs["loaded"] = True
        specs["runtime_threshold"] = _metadata.get("spam_threshold", config.SPAM_THRESHOLD)
    return specs
def print_model_specs() -> None:
    """Print a human-readable summary of get_model_specs()."""
    specs = get_model_specs()
    files = specs['files']
    print("Model Specs (ONNX)")
    print(f"- Model dir: {specs['model_dir']}")
    print(f"- Base threshold: {specs['spam_threshold']}")
    print(
        "- Files exist: "
        f"binary_onnx={files['binary_onnx_exists']}, "
        f"category_onnx={files['category_onnx_exists']}, "
        f"metadata={files['metadata_exists']}"
    )
    if not specs.get("loaded"):
        print("- Loaded: False")
        print(f"- Load error: {specs.get('load_error', 'unknown error')}")
    else:
        print("- Loaded: True")
        print(f"- Runtime threshold: {specs['runtime_threshold']}")
if __name__ == "__main__":
    # Running this module directly prints the model diagnostics.
    print_model_specs()