import json
import os
import re
from urllib.parse import urlparse

import numpy as np
import onnxruntime as ort

from utils import preprocess_text

# --------- DEFAULT CONFIG ---------
MODEL_DIR = "models"
SPAM_THRESHOLD = 0.5
SHORT_TEXT_WORD_COUNT = 6
SHORT_TEXT_THRESHOLD = 0.65
VERY_SHORT_TEXT_WORD_COUNT = 3
VERY_SHORT_TEXT_THRESHOLD = 0.75
LONG_TEXT_WORD_THRESHOLD = 80
CHUNK_MAX_WORDS = 40
MAX_CHUNKS = 24
BLOCKED_URL_DOMAINS = set()

# metadata.json is read from the filesystem root (e.g. the root of a container image).
METADATA_PATH = os.path.join("/", "metadata.json")

# --------- GLOBALS ---------
_binary_session = None
_category_session = None
_metadata = None

# --------- REGEX ---------
# Note: only GIVEAWAY_OVERRIDE_PATTERN is consulted by the prediction path below;
# the remaining patterns (and _contains_url) are defined but not currently referenced.
SPAM_HINT_PATTERN = re.compile(
    r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
    r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
    re.IGNORECASE,
)

SCAM_ACTION_PATTERN = re.compile(
    r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
    re.IGNORECASE,
)

GIVEAWAY_OVERRIDE_PATTERN = re.compile(
    r"(\b(won|winner|jackpot)\b.*\b(prize|reward|gift|voucher|iphone|cash)\b)",
    re.IGNORECASE,
)

URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)

LINK_SPAM_CUE_PATTERN = re.compile(
    r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|deposit)",
    re.IGNORECASE,
)

# --------- LOADERS ---------
def _load_metadata():
    """Load tuning values from metadata.json if present, else fall back to defaults."""
    if os.path.exists(METADATA_PATH):
        with open(METADATA_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    return {
        "spam_threshold": SPAM_THRESHOLD,
    }

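# Illustrative metadata.json shape ("spam_threshold" is the only key read here;
# the value 0.55 is a made-up example):
# {
#     "spam_threshold": 0.55
# }
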
def load_models():
    """Lazily create the ONNX inference sessions and load metadata, once per process."""
    global _binary_session, _category_session, _metadata
    if _binary_session is None:
        binary_path = os.path.join(MODEL_DIR, "binary_model.onnx")
        _binary_session = ort.InferenceSession(binary_path)
    if _category_session is None:
        category_path = os.path.join(MODEL_DIR, "category_model.onnx")
        _category_session = ort.InferenceSession(category_path)
    if _metadata is None:
        _metadata = _load_metadata()

# --------- HELPERS ---------
def _contains_url(text):
    return bool(URL_ANY_PATTERN.search(text or ""))


def _effective_threshold(text):
    """Return the spam threshold, raised for short and very short messages."""
    threshold = float(_metadata.get("spam_threshold", SPAM_THRESHOLD))
    words = text.split()
    if len(words) <= VERY_SHORT_TEXT_WORD_COUNT:
        threshold = max(threshold, VERY_SHORT_TEXT_THRESHOLD)
    elif len(words) <= SHORT_TEXT_WORD_COUNT:
        threshold = max(threshold, SHORT_TEXT_THRESHOLD)
    return threshold


def _split_text(text):
    """Split text into chunks of at most CHUNK_MAX_WORDS words, capped at MAX_CHUNKS chunks."""
    words = text.split()
    chunks = []
    for i in range(0, len(words), CHUNK_MAX_WORDS):
        chunk = " ".join(words[i:i + CHUNK_MAX_WORDS])
        chunks.append(chunk)
        if len(chunks) >= MAX_CHUNKS:
            break
    return chunks

# --------- CORE ---------
def _predict_single(raw_text, cleaned_text):
    """Score one piece of text with the binary model, then categorise it if it is spam."""
    onnx_input = np.array([[cleaned_text]], dtype=object)

    # Binary model: the second output holds the per-class probability map
    # (class keys may be ints or strings depending on how the model was exported).
    inputs = {_binary_session.get_inputs()[0].name: onnx_input}
    outputs = _binary_session.run(None, inputs)
    probs = outputs[1][0]
    spam_prob = float(probs.get(1, probs.get("1", 0.0)))

    threshold = _effective_threshold(cleaned_text)
    is_spam = spam_prob >= threshold

    # Heuristic override: "won ... prize"-style giveaway phrasing is flagged
    # even when the model score falls below the threshold.
    if not is_spam:
        if GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or ""):
            is_spam = True

    # The category model is only consulted for messages judged to be spam.
    if is_spam:
        cat_inputs = {_category_session.get_inputs()[0].name: onnx_input}
        cat_outputs = _category_session.run(None, cat_inputs)
        category = str(cat_outputs[0][0])
    else:
        category = "normal"

    return {
        "is_spam": is_spam,
        "confidence": spam_prob,
        "category": category,
        "threshold_used": threshold,
    }

# --------- PUBLIC API ---------
def predict_message(text):
    """Classify a message.

    Long texts are scored chunk by chunk: the message is spam if any chunk is
    spam, and the highest-confidence chunk supplies the reported confidence,
    category, and threshold.
    """
    load_models()
    cleaned = preprocess_text(text)
    word_count = len(cleaned.split())

    if word_count <= LONG_TEXT_WORD_THRESHOLD:
        pred = _predict_single(text, cleaned)
        pred["confidence"] = round(pred["confidence"], 4)
        return pred

    chunks = _split_text(text)
    preds = [_predict_single(chunk, preprocess_text(chunk)) for chunk in chunks]
    best = max(preds, key=lambda x: x["confidence"])
    return {
        "is_spam": any(p["is_spam"] for p in preds),
        "confidence": round(best["confidence"], 4),
        "category": best["category"] if best["is_spam"] else "normal",
        "threshold_used": best["threshold_used"],
    }

def run_model(text):
    """Validate the input and wrap the prediction in a simple response envelope."""
    if not text or not isinstance(text, str):
        return {"ok": False, "error": "Invalid input"}
    result = predict_message(text)
    return {
        "ok": True,
        "input": text.strip(),
        "result": result,
    }

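
# Minimal manual check (illustrative only): assumes models/binary_model.onnx and
# models/category_model.onnx are present and that utils.preprocess_text is importable;
# the sample messages below are made up.
if __name__ == "__main__":
    samples = [
        "Congratulations, you won a prize! Claim it now at http://example.com",
        "Are we still meeting for lunch tomorrow?",
    ]
    for sample in samples:
        print(json.dumps(run_model(sample), indent=2))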