import os

# Cap BLAS/OpenMP thread pools BEFORE numpy/torch are imported so the API
# process does not oversubscribe CPU cores.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")

import sys
import re
import json

# Make the project root importable so `src.*` resolves when run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from src.config import checkpoints, device, max_seq_len, data_processed, numeric_features

# Lazily-populated module state, filled in by load_model().
_model = None          # fitted classifier: xgboost model or torch module
_model_info = None     # dict loaded from best_model_info.json
_tokenizer = None      # GloveVocab for neural models; None for xgboost
_numeric_mean = None   # per-feature normalisation mean (from processed.pt)
_numeric_std = None    # per-feature normalisation std (from processed.pt)
_threshold = 0.5       # decision threshold chosen at training time


def load_model():
    """Load the best trained model and its preprocessing state, once.

    Idempotent: returns immediately if a model is already loaded.

    Raises:
        FileNotFoundError: if no `best_model_info.json` checkpoint exists yet.
    """
    global _model, _model_info, _tokenizer, _numeric_mean, _numeric_std, _threshold
    if _model is not None:
        return
    # Normalisation statistics saved during preprocessing (optional file).
    proc_path = os.path.join(data_processed, "processed.pt")
    if os.path.exists(proc_path):
        # weights_only=False: file contains non-tensor python objects.
        proc_data = torch.load(proc_path, weights_only=False)
        _numeric_mean = proc_data.get("numeric_mean")
        _numeric_std = proc_data.get("numeric_std")
    info_path = os.path.join(checkpoints, "best_model_info.json")
    if not os.path.exists(info_path):
        raise FileNotFoundError("No trained model. Run: python src/train.py")
    with open(info_path) as f:
        _model_info = json.load(f)
    name = _model_info["model_name"]
    model_type = _model_info.get("model_type", "neural")
    _threshold = float(_model_info.get("threshold", 0.5))
    if model_type == "xgboost":
        import xgboost as xgb
        _model = xgb.XGBClassifier()
        _model.load_model(os.path.join(checkpoints, f"{name}_best.json"))
        _tokenizer = None
    else:
        from src.data import GloveVocab
        _tokenizer = GloveVocab.load(os.path.join(checkpoints, "vocab.json"))
        from src.models import BiGRU_LSTM, CNN_BiLSTM
        _model = (BiGRU_LSTM(vocab_size=_tokenizer.vocab_size)
                  if name == "bigru_lstm"
                  else CNN_BiLSTM(vocab_size=_tokenizer.vocab_size))
        ckpt = os.path.join(checkpoints, f"{name}_best.pt")
        _model.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True))
        _model.to(device)
        _model.eval()


def prepare_text(profile):
    """Join bio + up to 20 recent tweets with " [SEP] ", strip URLs, collapse whitespace."""
    parts = []
    bio = str(profile.get("bio", "") or profile.get("description", "") or "")
    if bio.strip():
        parts.append(bio.strip())
    for t in (profile.get("recent_tweets", []) or [])[:20]:
        t = str(t).strip()
        if t:
            parts.append(t)
    combined = " [SEP] ".join(parts)
    combined = re.sub(r"http\S+", "", combined)
    return re.sub(r"\s+", " ", combined).strip() or ""


def extract_numeric(profile):
    """Build the raw numeric feature vector for one profile dict.

    Order MUST match src.config.numeric_features — `predict` indexes this list
    positionally (0=followers, 4=age, 16=verified, 35=bio_likely_org).
    Accepts both the extension's field names (e.g. following_count) and
    Twitter-API names (e.g. friends_count) as fallbacks.
    """
    followers = float(profile.get("followers_count", 0))
    friends = float(profile.get("following_count", 0) or profile.get("friends_count", 0))
    statuses = float(profile.get("tweet_count", 0) or profile.get("statuses_count", 0))
    favourites = float(profile.get("favourites_count", 0))
    # Clamp age to >= 1 day to keep per-day rates finite.
    age = max(float(profile.get("account_age_days", 365)), 1.0)
    tweets_per_day = statuses / age
    bio = str(profile.get("bio", "") or profile.get("description", "") or "")
    username = str(profile.get("username", "") or profile.get("screen_name", "") or "")
    location = str(profile.get("location", "") or "")
    verified = int(profile.get("is_verified", False) or profile.get("verified", False))
    default_profile = int(profile.get("default_profile", False))
    default_avatar = int(profile.get("has_default_avatar", False) or profile.get("default_profile_image", False))
    # Ratio features, all guarded against division by zero.
    f2f_ratio = followers / max(friends, 1)
    fav2stat_ratio = favourites / max(statuses, 1)
    fr2fol_ratio = friends / max(followers, 1)
    stat2fol_ratio = statuses / max(followers, 1)
    has_desc = int(len(bio) > 0)
    has_loc = int(len(location) > 0)
    # 0..5 count of filled-in / non-default profile fields.
    completeness = has_desc + has_loc + (1 - default_profile) + (1 - default_avatar) + verified
    sn_digits = sum(c.isdigit() for c in username)
    sn_digit_ratio = sn_digits / max(len(username), 1)
    sn_underscore = int("_" in username)
    tweets_per_follower = statuses / max(followers, 1)
    tpd_per_follower = tweets_per_day / max(followers, 1)
    bio_urls = len(re.findall(r"http|www\.|\.com|\.net", bio))
    bio_hashtags = bio.count("#")
    bio_mentions = bio.count("@")
    bio_words = len(bio.split()) if bio else 0
    # Keyword heuristics used by the news-org override in predict().
    news_pattern = r"\b(?:news|breaking|daily|magazine|journal|times|herald|tribune|gazette|broadcast|media|press|reporter|journalist|editor|anchor|correspondent|coverage|headlines|report)\b"
    org_pattern = r"\b(?:official|corp|inc\.?|llc|ltd|company|brand|store|shop|support|customer|service|team|foundation|organisation|organization|ngo|charity)\b"
    bio_lower = bio.lower()
    bio_has_news = int(bool(re.search(news_pattern, bio_lower)))
    bio_has_org = int(bool(re.search(org_pattern, bio_lower)))
    bio_likely_org = int((bio_has_news or bio_has_org) and followers > 1000 and age > 365)
    is_established = int(bool(verified) and followers > 10000 and age > 365)
    # Log-scaled copies of the heavy-tailed counts.
    log_followers = float(np.log1p(followers))
    log_friends = float(np.log1p(friends))
    log_statuses = float(np.log1p(statuses))
    log_favourites = float(np.log1p(favourites))
    log_tpf = float(np.log1p(tweets_per_follower))
    log_f2f = float(np.log1p(f2f_ratio))
    return [
        followers, friends, statuses, favourites, age, tweets_per_day,
        log_followers, log_friends, log_statuses, log_favourites, log_tpf, log_f2f,
        f2f_ratio, fav2stat_ratio, fr2fol_ratio, stat2fol_ratio,
        verified, default_profile, default_avatar, has_desc, has_loc,
        completeness, len(bio), len(username), sn_digits, sn_digit_ratio,
        sn_underscore, tweets_per_follower, tpd_per_follower, bio_urls,
        bio_hashtags, bio_mentions, bio_words, bio_has_news, bio_has_org,
        bio_likely_org, is_established,
    ]


# Human-readable explanations for each numeric feature, keyed by the names in
# src.config.numeric_features; used when formatting XGBoost contributions.
feature_descriptions = {
    "followers_count": "total followers",
    "friends_count": "total accounts followed",
    "statuses_count": "total tweets posted",
    "favourites_count": "total likes given",
    "account_age_days": "how long the account has existed",
    "average_tweets_per_day": "tweets posted per day on average",
    "log_followers_count": "follower count (log scale)",
    "log_friends_count": "following count (log scale)",
    "log_statuses_count": "tweet count (log scale)",
    "log_favourites_count": "likes given (log scale)",
    "log_tweets_per_follower": "tweets per follower (log scale)",
    "log_followers_to_friends_ratio": "follower-to-following balance (log scale)",
    "followers_to_friends_ratio": "how many followers per account followed",
    "favourites_to_statuses_ratio": "likes given per tweet posted",
    "friends_to_followers_ratio": "how many followed per follower",
    "statuses_to_followers_ratio": "tweets per follower",
    "verified": "has the verified blue checkmark",
    "default_profile": "still using the default profile theme",
    "default_profile_image": "still using the default avatar",
    "has_description": "has filled in a bio",
    "has_location": "has filled in a location",
    "profile_completeness": "how many profile fields are filled in",
    "description_length": "length of the bio",
    "screen_name_length": "length of the username",
    "screen_name_digits": "number of digits in the username",
    "screen_name_digit_ratio": "fraction of the username that is digits",
    "screen_name_has_underscore": "username contains an underscore",
    "tweets_per_follower": "tweets posted per follower",
    "tweets_per_day_per_follower": "tweets per day relative to followers",
    "bio_url_count": "URLs in the bio",
    "bio_hashtag_count": "hashtags in the bio",
    "bio_mention_count": "mentions in the bio",
    "bio_word_count": "words in the bio",
    "bio_has_news_keywords": "bio mentions news or journalism",
    "bio_has_org_keywords": "bio mentions an organisation or brand",
    "bio_likely_organisation": "bio plus reach suggests a real organisation",
    "is_established_account": "verified, large following, account older than one year",
}


def format_feature_value(name, value):
    """Render a raw feature value as a short display string for the UI."""
    if name == "verified":
        return "yes" if value > 0.5 else "no"
    if name in ("default_profile", "default_profile_image", "has_description",
                "has_location", "screen_name_has_underscore", "bio_has_news_keywords",
                "bio_has_org_keywords", "bio_likely_organisation", "is_established_account"):
        return "yes" if value > 0.5 else "no"
    if name == "account_age_days":
        years = value / 365.0
        if years >= 1:
            return f"{years:.1f} yrs"
        return f"{int(value)} days"
    if name in ("followers_count", "friends_count", "statuses_count", "favourites_count"):
        if value >= 1_000_000:
            return f"{value/1_000_000:.1f}M"
        if value >= 1_000:
            return f"{value/1_000:.1f}K"
        return str(int(value))
    if name == "average_tweets_per_day":
        return f"{value:.1f}/day"
    if name == "profile_completeness":
        return f"{int(value)}/5"
    if name == "screen_name_length":
        return f"{int(value)} chars"
    if name.startswith("log_"):
        return f"{value:.2f}"
    if "ratio" in name:
        return f"{value:.2f}"
    if isinstance(value, float):
        return f"{value:.1f}"
    return str(value)


def compute_contributions(numeric_arr, raw_numeric):
    """Return top per-feature SHAP-style contributions for XGBoost models.

    Uses booster.predict(pred_contribs=True); the last element of the contribs
    vector is the bias term and is dropped. Returns up to 4 features pushing
    toward "bot" and 4 toward "human", or None for neural models.
    """
    if _model_info.get("model_type") != "xgboost":
        return None
    import xgboost as xgb
    booster = _model.get_booster()
    dmatrix = xgb.DMatrix(numeric_arr.reshape(1, -1))
    contribs = booster.predict(dmatrix, pred_contribs=True)[0]
    feat_contribs = contribs[:-1]  # drop the bias term
    indexed = sorted(enumerate(feat_contribs), key=lambda x: abs(x[1]), reverse=True)
    # Percentages are relative to the total mass of non-negligible contributions.
    total_abs = sum(abs(c) for _, c in indexed if abs(c) >= 0.01)
    toward_bot, toward_human = [], []
    for idx, contrib in indexed:
        if abs(contrib) < 0.01:
            continue
        if len(toward_bot) >= 4 and len(toward_human) >= 4:
            break
        name = numeric_features[idx]
        entry = {
            "feature": name,
            "description": feature_descriptions.get(name, name.replace("_", " ")),
            "value": format_feature_value(name, float(raw_numeric[idx])),
            "contribution": round(float(contrib), 3),
            "percentage": round(float(abs(contrib) / max(total_abs, 0.001)) * 100, 1),
        }
        if contrib > 0 and len(toward_bot) < 4:
            toward_bot.append(entry)
        elif contrib < 0 and len(toward_human) < 4:
            toward_human.append(entry)
    return {"toward_bot": toward_bot, "toward_human": toward_human}


def generate_signals(profile, score):
    """Produce human-readable heuristic red flags for the response payload."""
    signals = []
    followers = int(profile.get("followers_count", 0))
    following = int(profile.get("following_count", 0) or profile.get("friends_count", 0))
    tweets = int(profile.get("tweet_count", 0) or profile.get("statuses_count", 0))
    age = max(int(profile.get("account_age_days", 365)), 1)
    if followers / max(following, 1) < 0.1 and following > 100:
        signals.append("Very low follower-to-following ratio")
    if age < 30:
        signals.append("Account is less than 30 days old")
    if tweets / age > 50:
        signals.append("Extremely high tweet frequency")
    if profile.get("has_default_avatar", False) or profile.get("default_profile_image", False):
        signals.append("Using default profile image")
    if followers < 5 and following > 500:
        signals.append("Mass-following with few followers")
    if len(str(profile.get("bio", "") or "")) < 5:
        signals.append("Empty or very short bio")
    if not signals and score >= 70:
        signals.append("Text patterns indicate automated content")
    if not signals:
        signals.append("No strong bot signals detected")
    return signals


def predict(profile):
    """Score one profile dict and return the full response payload as a dict.

    Applies the trained model (xgboost on numeric features, or a neural model
    on text + numeric), then a "news org" probability cap for verified,
    established, organisation-like accounts, and maps the probability to a
    bot/human/uncertain label around the trained threshold.
    """
    load_model()
    raw_numeric = extract_numeric(profile)
    numeric_arr = np.array(raw_numeric, dtype=np.float32)
    if _numeric_mean is not None and _numeric_std is not None:
        numeric_arr = (numeric_arr - _numeric_mean) / _numeric_std
    model_type = _model_info.get("model_type", "neural")
    if model_type == "xgboost":
        bot_prob = float(_model.predict_proba(numeric_arr.reshape(1, -1))[0, 1])
    else:
        text = prepare_text(profile)
        numeric = torch.tensor([numeric_arr], dtype=torch.float32, device=device)
        with torch.no_grad():
            tokens = _tokenizer.tokenize_batch([text], max_len=max_seq_len).to(device)
            logits = _model(input_ids=tokens, numeric=numeric)
            bot_prob = torch.sigmoid(logits.squeeze()).item()
    # Positional reads from the raw feature vector (see extract_numeric).
    raw_followers, raw_age = raw_numeric[0], raw_numeric[4]
    raw_verified, raw_likely_org = raw_numeric[16], raw_numeric[35]
    override_applied = None
    # Verified, year-old, >10K-follower org/news accounts get their bot
    # probability capped safely below the threshold.
    if raw_likely_org and raw_verified and raw_age > 365 and raw_followers > 10_000:
        capped = max(0.0, _threshold - 0.15)
        if bot_prob > capped:
            override_applied = "news_org"
            bot_prob = min(bot_prob, capped)
    score = int(round(bot_prob * 100))
    # Young accounts get a wider "uncertain" band.
    margin = 0.18 if raw_age < 60 else 0.1
    delta = bot_prob - _threshold
    if abs(delta) <= margin:
        label = "uncertain"
    elif delta > 0:
        label = "bot"
    else:
        label = "human"
    return {
        "username": profile.get("username", ""),
        "bot_probability": round(bot_prob, 4),
        "bot_score": score,
        "label": label,
        "confidence": "high" if abs(delta) > 0.3 else ("medium" if abs(delta) > 0.15 else "low"),
        "signals": generate_signals(profile, score),
        "contributions": compute_contributions(numeric_arr, raw_numeric),
        "override_applied": override_applied,
        "threshold": round(_threshold, 4),
        "margin": round(margin, 4),
    }


class PredictRequest(BaseModel):
    # Incoming profile payload; pydantic deep-copies mutable defaults per-instance.
    username: str
    display_name: str = ""
    bio: str = ""
    followers_count: int = 0
    following_count: int = 0
    tweet_count: int = 0
    listed_count: int = 0
    account_age_days: int = 365
    recent_tweets: list[str] = []
    has_default_avatar: bool = False
    is_verified: bool = False
    url: str = ""


class PredictResponse(BaseModel):
    # Mirrors the dict returned by predict().
    username: str
    bot_probability: float
    bot_score: int
    label: str
    confidence: str
    signals: list[str]
    contributions: dict | None = None
    override_applied: str | None = None
    threshold: float = 0.5
    margin: float = 0.1


app = FastAPI(title="Twitter Bot Detector API", version="1.0.0")
# Only the Twitter/X web origins and browser extensions may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"^(https://(x|twitter)\.com|chrome-extension://.*)$",
    allow_credentials=False,
    allow_methods=["POST", "GET"],
    allow_headers=["Content-Type"],
)


@app.on_event("startup")
async def startup():
    # Best-effort warm load; the API still serves /health without a model.
    try:
        load_model()
        print("[+] Model loaded")
    except FileNotFoundError:
        print("[!] No model found, train first with: python src/train.py")
    except Exception as e:
        print(f"[!] Model load failed: {e}")


@app.post("/predict", response_model=PredictResponse)
async def predict_endpoint(request: PredictRequest):
    try:
        return PredictResponse(**predict(request.model_dump()))
    except FileNotFoundError as e:
        # No trained model available yet.
        raise HTTPException(status_code=503, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class BatchRequest(BaseModel):
    profiles: list[PredictRequest]


class BatchResponse(BaseModel):
    results: list[PredictResponse]


@app.post("/predict_batch", response_model=BatchResponse)
async def predict_batch_endpoint(request: BatchRequest):
    if len(request.profiles) > 50:
        raise HTTPException(status_code=429, detail="batch limit is 50 profiles")
    try:
        results = [PredictResponse(**predict(p.model_dump())) for p in request.profiles]
        return BatchResponse(results=results)
    except FileNotFoundError as e:
        raise HTTPException(status_code=503, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class ThreadReplyRequest(BaseModel):
    username: str
    display_name: str = ""
    is_verified: bool = False


class ThreadReplyResponse(BaseModel):
    username: str
    flag: str
    reasons: list[str]


class ThreadReplyBatchRequest(BaseModel):
    replies: list[ThreadReplyRequest]


class ThreadReplyBatchResponse(BaseModel):
    results: list[ThreadReplyResponse]


def score_thread_reply(profile):
    """Cheap username-only heuristic for thread replies (no model involved).

    Returns a flag of "typical", "possibly_suspicious" (1-2 signals) or
    "suspicious" (3+ signals), with the reasons that fired. Verified accounts
    short-circuit to "typical".
    """
    username = profile.get("username", "")
    is_verified = profile.get("is_verified", False)
    if is_verified:
        return {
            "username": username,
            "flag": "typical",
            "reasons": ["verified account"],
        }
    signals = 0
    reasons = []
    digits = sum(c.isdigit() for c in username)
    if digits >= 5:
        signals += 2
        reasons.append(f"username contains {digits} digits")
    elif digits >= 3:
        signals += 1
        reasons.append(f"username contains {digits} digits")
    if re.search(r"\d{4,}$", username):
        signals += 1
        reasons.append("username ends in long digit sequence")
    if len(username) >= 12 and digits / max(len(username), 1) > 0.3:
        signals += 1
        reasons.append("handle is mostly digits")
    if re.match(r"^[a-z]+\d+$", username.lower()):
        signals += 1
        reasons.append("handle follows auto-generated pattern")
    if signals >= 3:
        flag = "suspicious"
    elif signals >= 1:
        flag = "possibly_suspicious"
    else:
        flag = "typical"
        reasons = ["no obvious red flags in visible info"]
    return {"username": username, "flag": flag, "reasons": reasons}


@app.post("/predict_thread_batch", response_model=ThreadReplyBatchResponse)
async def predict_thread_batch_endpoint(request: ThreadReplyBatchRequest):
    if len(request.replies) > 100:
        raise HTTPException(status_code=429, detail="batch limit is 100 replies")
    results = [ThreadReplyResponse(**score_thread_reply(r.model_dump())) for r in request.replies]
    return ThreadReplyBatchResponse(results=results)


@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model_loaded": _model is not None,
        "model_name": _model_info.get("model_name", "") if _model_info else "",
    }