Spaces:
Sleeping
Sleeping
import json
import logging
import os
import re
import threading
from typing import Dict, List, Optional

import nltk
import torch
from fastapi import FastAPI, HTTPException
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize
from pydantic import BaseModel
from textblob import TextBlob
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Download NLTK data
def _ensure_nltk_data() -> None:
    """Download each NLTK resource this module uses if it is not installed yet.

    Every resource is probed separately. Probing only ``punkt`` meant that an
    environment which already had punkt but lacked ``stopwords`` or ``wordnet``
    skipped all downloads, and ``TextPreprocessor`` then failed with
    ``LookupError``.
    """
    required = (
        ("tokenizers/punkt", "punkt"),
        # Newer NLTK releases load `punkt_tab` for word_tokenize. On releases
        # that do not know this package, nltk.download (raise_on_error=False by
        # default) reports the error instead of raising — confirm on upgrade.
        ("tokenizers/punkt_tab", "punkt_tab"),
        ("corpora/stopwords", "stopwords"),
        ("corpora/wordnet", "wordnet"),
    )
    for resource_path, package in required:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)


_ensure_nltk_data()
# Hugging Face model repo to serve. MODEL_ID wins over HF_MODEL_ID; unset or
# empty values fall through to the default repo id.
MODEL_ID = next(
    (
        candidate
        for candidate in (os.environ.get("MODEL_ID"), os.environ.get("HF_MODEL_ID"))
        if candidate
    ),
    "Perth0603/phishing-email-mobilebert",
)
# ASGI application object that the endpoint functions are attached to.
app = FastAPI(
    version="1.0.0",
    title="Phishing Text Classifier with Preprocessing",
)
# ============================================================================
# TEXT PREPROCESSING CLASS
# ============================================================================
class TextPreprocessor:
    """Classic NLP pipeline that annotates a text alongside the model prediction.

    Offers tokenization, stop-word filtering, stemming, lemmatization, and a
    TextBlob sentiment pass combined with a few regex-based phishing cues.
    """

    # Patterns used by sentiment_analysis(), compiled once for the class.
    _URGENT_RE = re.compile(
        r'\b(urgent|immediate|act now|verify|confirm|update|click|verify account)\b',
        re.IGNORECASE,
    )
    _THREAT_RE = re.compile(
        r'\b(suspend|limited|expire|locked|disabled|restricted)\b',
        re.IGNORECASE,
    )
    _URL_RE = re.compile(r'http\S+|www\S+')
    _HIGH_URGENCY_RE = re.compile(r'\b(urgent|immediate|act now)\b', re.IGNORECASE)

    def __init__(self):
        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))

    def tokenize(self, text: str) -> List[str]:
        """Lower-case *text* and split it into NLTK word tokens."""
        lowered = text.lower()
        return word_tokenize(lowered)

    def remove_stopwords(self, tokens: List[str]) -> List[str]:
        """Keep tokens that are purely alphanumeric and not English stop words."""
        stop = self.stop_words
        return [tok for tok in tokens if tok.isalnum() and tok not in stop]

    def stem(self, tokens: List[str]) -> List[str]:
        """Apply the Porter stemmer to every token."""
        return list(map(self.stemmer.stem, tokens))

    def lemmatize(self, tokens: List[str]) -> List[str]:
        """Apply the WordNet lemmatizer (default part of speech) to every token."""
        return list(map(self.lemmatizer.lemmatize, tokens))

    def sentiment_analysis(self, text: str) -> Dict:
        """Score TextBlob polarity/subjectivity and flag simple phishing cues.

        Polarity above 0.1 is "positive", below -0.1 "negative", otherwise
        "neutral"; subjectivity above 0.5 marks the text as persuasive.
        """
        scores = TextBlob(text).sentiment
        polarity, subjectivity = scores.polarity, scores.subjectivity

        if polarity > 0.1:
            mood = "positive"
        elif polarity < -0.1:
            mood = "negative"
        else:
            mood = "neutral"

        indicators = {
            "urgent_words": bool(self._URGENT_RE.search(text)),
            "threat_words": bool(self._THREAT_RE.search(text)),
            # True for any http.../www... substring, not only known-bad URLs.
            "suspicious_urls": bool(self._URL_RE.search(text)),
            "urgency_level": "HIGH" if self._HIGH_URGENCY_RE.search(text) else "LOW",
        }

        return {
            "polarity": round(polarity, 4),
            "subjectivity": round(subjectivity, 4),
            "sentiment": mood,
            "is_persuasive": subjectivity > 0.5,
            "phishing_indicators": indicators,
        }

    def preprocess(self, text: str) -> Dict:
        """Run the full pipeline on *text* and collect every intermediate result."""
        all_tokens = self.tokenize(text)
        content_tokens = self.remove_stopwords(all_tokens)
        return {
            "original_text": text,
            "tokens": all_tokens,
            "tokens_without_stopwords": content_tokens,
            "stemmed_tokens": self.stem(content_tokens),
            "lemmatized_tokens": self.lemmatize(content_tokens),
            "sentiment": self.sentiment_analysis(text),
            "token_count": len(content_tokens),
        }
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class PredictPayload(BaseModel):
    """Request body for classifying a single text."""

    inputs: str  # text to classify
    include_preprocessing: bool = True  # attach TextPreprocessor output to the result


class BatchPredictPayload(BaseModel):
    """Request body for classifying several texts in one call."""

    inputs: List[str]  # texts to classify; results come back in the same order
    include_preprocessing: bool = True  # attach TextPreprocessor output to each result


class LabeledText(BaseModel):
    """One evaluation sample: a text plus an optional ground-truth label."""

    text: str
    label: Optional[str] = None  # None = unlabeled; such samples are left out of accuracy


class EvalPayload(BaseModel):
    """Request body for evaluating the classifier on labeled samples."""

    samples: List[LabeledText]
# ============================================================================
# GLOBAL VARIABLES
# ============================================================================
# Lazily populated state; _load_model() fills these in on first use.
_tokenizer = _model = _preprocessor = _LABEL_MAPPING = None
# Replaced with "cuda" by _load_model() when torch reports a GPU.
_device = "cpu"
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def _get_label_mapping():
    """Build the class-index -> label-name mapping used for predictions.

    Reads ``id2label`` and ``num_labels`` from the loaded model's config.
    Indices missing from ``id2label`` get a ``LABEL_<i>`` placeholder. When a
    two-class model carries only such placeholder names, i.e. its config has
    no real label names, the fallback ``{0: "LEGIT", 1: "PHISH"}`` is applied.
    Otherwise ``is_phish`` could never be true for that model.

    Returns:
        dict mapping int class index to label name, or None if no model is loaded.
    """
    if _model is None:
        return None

    config = _model.config
    id2label = getattr(config, "id2label", {}) or {}
    num_labels = int(getattr(config, "num_labels", 0) or 0)
    print(f"[DEBUG] Raw id2label from config: {id2label}")
    print(f"[DEBUG] num_labels: {num_labels}")

    # Build complete mapping by index; id2label may be keyed by str or int.
    complete_mapping = {}
    for i in range(num_labels):
        if str(i) in id2label:
            complete_mapping[i] = id2label[str(i)]
        elif i in id2label:
            complete_mapping[i] = id2label[i]
        else:
            complete_mapping[i] = f"LABEL_{i}"

    # The loop above always fills every index, so "incomplete" is detected as
    # "only placeholder names" rather than by counting entries.
    only_placeholders = all(
        str(name).strip().upper() == f"LABEL_{i}" for i, name in complete_mapping.items()
    )
    if num_labels == 2 and only_placeholders:
        print("[WARNING] Model config has no label names! Using fallback LEGIT/PHISH mapping.")
        complete_mapping = {
            0: "LEGIT",
            1: "PHISH",
        }

    print(f"[DEBUG] Complete mapping applied: {complete_mapping}")
    return complete_mapping
# Canonical label for every accepted alias (model label names or ground truth).
_LABEL_ALIASES = {
    **dict.fromkeys(("PHISHING", "PHISH", "SPAM", "1"), "PHISH"),
    **dict.fromkeys(("LEGIT", "LEGITIMATE", "SAFE", "HAM", "0"), "LEGIT"),
}


def _normalize_label(txt: str) -> str:
    """Return the canonical label ("PHISH" or "LEGIT") for *txt*.

    The value is stringified, stripped and upper-cased first, and ``None``
    becomes the empty string. Values that are not a known alias are returned
    in that cleaned form.
    """
    cleaned = "" if txt is None else str(txt)
    cleaned = cleaned.strip().upper()
    return _LABEL_ALIASES.get(cleaned, cleaned)
# Serializes the lazy initialization performed by _load_model().
_MODEL_LOAD_LOCK = threading.Lock()


def _load_model():
    """Load the model, tokenizer, preprocessor and label mapping on first use.

    Populates the module globals ``_tokenizer``, ``_model``, ``_device``,
    ``_preprocessor`` and ``_LABEL_MAPPING`` and runs one warm-up forward pass.
    Later calls return immediately once all of that state is present.

    Initialization runs under ``_MODEL_LOAD_LOCK``. The endpoints are plain
    ``def`` functions, which FastAPI runs in a thread pool, so first requests
    can overlap. Without the lock, a second caller could see ``_tokenizer`` and
    ``_model`` already assigned and continue while ``_preprocessor`` and
    ``_LABEL_MAPPING`` were still ``None``. Checking every global, not just
    tokenizer and model, also retries a load that failed half-way (e.g. in
    ``_model.to``) instead of leaving the module half-initialized.
    """
    global _tokenizer, _model, _device, _preprocessor, _LABEL_MAPPING
    with _MODEL_LOAD_LOCK:
        if all(obj is not None for obj in (_tokenizer, _model, _preprocessor, _LABEL_MAPPING)):
            return

        _device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"\n{'='*60}")
        print(f"Loading model on device: {_device}")
        print(f"Model ID: {MODEL_ID}")
        print(f"{'='*60}\n")

        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        _model.to(_device)
        _model.eval()
        _preprocessor = TextPreprocessor()

        # Index -> label name; reads the global _model assigned above.
        _LABEL_MAPPING = _get_label_mapping()

        # Warm-up forward pass (presumably so the first real request avoids
        # one-time setup cost).
        with torch.no_grad():
            _ = _model(
                **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
                .to(_device)
            ).logits
        print(f"{'='*60}\n")
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
    """Classify each text and return one result dict per input, in input order.

    Class indices are positions along the last dimension of the model logits.
    ``_LABEL_MAPPING`` turns an index into a label name, with ``CLASS_<i>`` as
    the fallback for indices it does not cover.

    Each result contains:
        text: the input, cut to 100 characters plus "..." when longer.
        predicted_class_index / raw_label / label / is_phish: the arg-max
            class, its mapped name, that name after _normalize_label, and
            whether the normalized name is "PHISH".
        score / confidence: softmax probability of the predicted class, as a
            fraction (4 d.p.) and as a percentage (2 d.p.).
        probs_by_class: probability of every class keyed by label name.
        all_probs_raw: the same probabilities as a list ordered by class index.
        preprocessing: TextPreprocessor.preprocess output, present only when
            ``include_preprocessing`` is true.

    Per-request diagnostics go to this module's logger at DEBUG level rather
    than to stdout.
    """
    _load_model()
    if not texts:
        return []

    log = logging.getLogger(__name__)
    label_mapping = _LABEL_MAPPING or {}

    preprocessing_info = None
    if include_preprocessing:
        preprocessing_info = [_preprocessor.preprocess(text) for text in texts]

    enc = _tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    enc = {k: v.to(_device) for k, v in enc.items()}

    with torch.no_grad():
        logits = _model(**enc).logits
    probs = torch.softmax(logits, dim=-1)
    num_labels = probs.shape[-1]
    log.debug("num_labels from probs shape: %d", num_labels)

    # Copy the batch to host memory once instead of one `.item()` read per
    # class per text.
    prob_rows = probs.cpu().tolist()
    predicted_indices = torch.argmax(probs, dim=-1).tolist()

    outputs: List[Dict] = []
    for text_idx, (text, row, predicted_idx) in enumerate(zip(texts, prob_rows, predicted_indices)):
        prob_breakdown = {}
        for class_idx, class_prob in enumerate(row):
            class_label = label_mapping.get(class_idx, f"CLASS_{class_idx}")
            prob_breakdown[class_label] = round(class_prob, 4)
            log.debug("Class %d (%s): %.4f", class_idx, class_label, class_prob)

        predicted_label_raw = label_mapping.get(predicted_idx, f"CLASS_{predicted_idx}")
        predicted_label_norm = _normalize_label(predicted_label_raw)
        predicted_prob = row[predicted_idx]
        log.debug(
            "ARGMAX: index=%d, label=%s, prob=%.4f, normalized=%s",
            predicted_idx, predicted_label_raw, predicted_prob, predicted_label_norm,
        )

        output = {
            "text": text[:100] + "..." if len(text) > 100 else text,
            "predicted_class_index": predicted_idx,
            "label": predicted_label_norm,
            "raw_label": predicted_label_raw,
            "is_phish": predicted_label_norm == "PHISH",
            "score": round(predicted_prob, 4),
            "confidence": round(predicted_prob * 100, 2),
            "probs_by_class": prob_breakdown,
            "all_probs_raw": [round(value, 4) for value in row],
        }
        if preprocessing_info is not None:
            output["preprocessing"] = preprocessing_info[text_idx]
        outputs.append(output)

    return outputs
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.get("/")
def root():
    """Health/info endpoint: load the model if needed and report its settings.

    Returns the serving model id, device, and the index -> label mapping in use.
    """
    _load_model()
    return {
        "status": "ok",
        "model": MODEL_ID,
        "device": _device,
        "label_mapping": _LABEL_MAPPING,
    }
# NOTE(review): the function was not registered as a route; path inferred from
# its name — confirm against existing clients.
@app.get("/debug/labels")
def debug_labels():
    """Show the model config's raw label fields next to the mapping actually applied."""
    _load_model()
    id2label_raw = getattr(_model.config, "id2label", {}) or {}
    label2id_raw = getattr(_model.config, "label2id", {}) or {}
    num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
    return {
        "status": "ok",
        "model_config_id2label": id2label_raw,
        "model_config_label2id": label2id_raw,
        "model_config_num_labels": num_labels,
        "applied_mapping": _LABEL_MAPPING,
        "device": _device,
        "note": "applied_mapping is what gets used for predictions"
    }
# NOTE(review): the function was not registered as a route; path inferred from
# its name — confirm against existing clients.
@app.post("/debug/preprocessing")
def debug_preprocessing(payload: PredictPayload):
    """Return only the TextPreprocessor analysis of ``payload.inputs`` (no model prediction).

    Raises:
        HTTPException(500): on any error, with the exception text as detail.
    """
    try:
        _load_model()
        preprocessing = _preprocessor.preprocess(payload.inputs)
        return {
            "status": "ok",
            "preprocessing": preprocessing
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e
# NOTE(review): the function was not registered as a route; path inferred from
# its name — confirm against existing clients.
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify a single text and return its result dict.

    Raises:
        HTTPException(500): on any error, with the exception text as detail.
    """
    try:
        res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
        return res[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e
# NOTE(review): the function was not registered as a route; path inferred from
# its name — confirm against existing clients.
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify several texts; results are returned in input order.

    Raises:
        HTTPException(500): on any error, with the exception text as detail.
    """
    try:
        return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e
# NOTE(review): the function was not registered as a route; path inferred from
# its name — confirm against existing clients.
@app.post("/evaluate")
def evaluate(payload: EvalPayload):
    """Classify labeled samples and report accuracy against their ground truth.

    Ground-truth labels go through ``_normalize_label``. Samples without a
    label are predicted but excluded from ``accuracy``, ``correct`` and
    ``per_class``. ``per_class`` is keyed by normalized ground-truth label:
    ``count`` is the number of samples with that label, and ``tp`` is how many
    of those were predicted as that label. ``accuracy`` is None when no sample
    is labeled.

    Raises:
        HTTPException(500): on any error, with the exception text as detail.
    """
    try:
        texts = [s.text for s in payload.samples]
        gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
        preds = _predict_texts(texts, include_preprocessing=False)

        correct = 0
        labeled = 0
        per_class: Dict[str, Dict[str, int]] = {}
        for gt, pr in zip(gts, preds):
            if gt is None:
                continue
            hit = gt == pr["label"]
            labeled += 1
            correct += int(hit)
            stats = per_class.setdefault(gt, {"tp": 0, "count": 0})
            stats["count"] += 1
            stats["tp"] += int(hit)

        acc = (correct / labeled) if labeled else None
        return {
            # Compare with None: an accuracy of exactly 0.0 must stay 0.0.
            "accuracy": round(acc, 4) if acc is not None else None,
            "total": len(preds),
            "correct": correct,
            "predictions": preds,
            "per_class": per_class,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e
if __name__ == "__main__":
    import uvicorn

    # The PORT env var lets the host choose the port; 8000 remains the default.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))