# (Hugging Face Spaces status banner captured with the source — not program code)
# Spaces: Running
"""
FYP4 SPAM DETECTION API - Loads model from Hugging Face Hub
"""
| import os | |
| import re | |
| import io | |
| import torch | |
| import torch.nn as nn | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| import PyPDF2 | |
| import pdfplumber | |
| from PIL import Image | |
| from transformers import DebertaV2Model, DebertaV2Tokenizer | |
| # ================================ | |
| # CONFIGURATION | |
| # ================================ | |
| class Config: | |
| DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| TEXT_MODEL = 'microsoft/deberta-v3-base' | |
| # π₯ CHANGE THIS TO YOUR MODEL REPO! π₯ | |
| MODEL_REPO = "haroon103/fyp4-spam-model" # β Change haroon103 to YOUR username | |
| TEXT_HIDDEN_DIM = 768 | |
| FUSION_DIM = 512 | |
| NUM_CLASSES = 2 | |
| DROPOUT = 0.3 | |
| MAX_TEXT_LENGTH = 256 | |
| config = Config() | |
| # ================================ | |
| # MODEL ARCHITECTURES | |
| # ================================ | |
| class DeBERTaTextEncoder(nn.Module): | |
| def __init__(self, dropout=0.3): | |
| super(DeBERTaTextEncoder, self).__init__() | |
| self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL) | |
| self.projection = nn.Sequential( | |
| nn.Dropout(dropout), | |
| nn.Linear(config.TEXT_HIDDEN_DIM, config.FUSION_DIM), | |
| nn.LayerNorm(config.FUSION_DIM), | |
| nn.GELU() | |
| ) | |
| def forward(self, input_ids, attention_mask): | |
| outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask) | |
| pooled = outputs.last_hidden_state[:, 0, :] | |
| return self.projection(pooled) | |
| class TextSpamClassifier(nn.Module): | |
| def __init__(self, dropout=0.3): | |
| super(TextSpamClassifier, self).__init__() | |
| self.text_encoder = DeBERTaTextEncoder(dropout) | |
| self.classifier = nn.Sequential( | |
| nn.Linear(config.FUSION_DIM, 256), | |
| nn.LayerNorm(256), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(256, 128), | |
| nn.LayerNorm(128), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(128, config.NUM_CLASSES) | |
| ) | |
| def forward(self, input_ids, attention_mask): | |
| features = self.text_encoder(input_ids, attention_mask) | |
| return self.classifier(features) | |
| # ================================ | |
| # PDF EXTRACTION | |
| # ================================ | |
| class PDFExtractor: | |
| def extract_text_from_pdf(pdf_bytes): | |
| email_data = {'subject': '', 'sender': '', 'body': '', 'full_text': ''} | |
| try: | |
| pdf_file = io.BytesIO(pdf_bytes) | |
| with pdfplumber.open(pdf_file) as pdf: | |
| full_text = "" | |
| for page in pdf.pages: | |
| text = page.extract_text() | |
| if text: | |
| full_text += text + "\n" | |
| email_data['full_text'] = full_text | |
| patterns = { | |
| 'subject': [r'Subject:\s*(.+)', r'SUBJECT:\s*(.+)'], | |
| 'sender': [r'From:\s*(.+)', r'FROM:\s*(.+)'] | |
| } | |
| for field, pattern_list in patterns.items(): | |
| for pattern in pattern_list: | |
| match = re.search(pattern, full_text, re.IGNORECASE) | |
| if match: | |
| email_data[field] = match.group(1).strip()[:100] | |
| break | |
| body_match = re.search(r'(?:Subject|Date|From|To):.+?\n\n(.+)', full_text, re.DOTALL | re.IGNORECASE) | |
| if body_match: | |
| email_data['body'] = body_match.group(1).strip() | |
| else: | |
| email_data['body'] = full_text | |
| return email_data | |
| except: | |
| try: | |
| pdf_file = io.BytesIO(pdf_bytes) | |
| pdf_reader = PyPDF2.PdfReader(pdf_file) | |
| full_text = "" | |
| for page in pdf_reader.pages: | |
| text = page.extract_text() | |
| if text: | |
| full_text += text + "\n" | |
| email_data['full_text'] = full_text | |
| email_data['body'] = full_text | |
| return email_data | |
| except Exception as e: | |
| raise HTTPException(status_code=400, detail=f"Error: {str(e)}") | |
| # ================================ | |
| # TEXT PREPROCESSING | |
| # ================================ | |
| def preprocess_text(text): | |
| text = str(text).lower() | |
| text = re.sub(r'http\S+|www\.\S+', '[URL]', text) | |
| text = re.sub(r'\S+@\S+', '[EMAIL]', text) | |
| text = re.sub(r'\d+', '[NUM]', text) | |
| text = re.sub(r'\s+', ' ', text).strip() | |
| return text | |
| # ================================ | |
| # SPAM DETECTOR | |
| # ================================ | |
| class SpamDetector: | |
| def __init__(self): | |
| self.device = config.DEVICE | |
| self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL) | |
| self.model = None | |
| try: | |
| print(f"π₯ Downloading model from {config.MODEL_REPO}...") | |
| model_path = hf_hub_download( | |
| repo_id=config.MODEL_REPO, | |
| filename="model.safetensors" | |
| ) | |
| print(f"β Downloaded to: {model_path}") | |
| print("π§ Loading model...") | |
| self.model = TextSpamClassifier(dropout=config.DROPOUT).to(self.device) | |
| state_dict = load_file(model_path) | |
| self.model.load_state_dict(state_dict, strict=False) | |
| self.model.eval() | |
| print("β Model ready!") | |
| except Exception as e: | |
| print(f"β Error: {e}") | |
| self.model = None | |
| def predict_text(self, text): | |
| if not self.model: | |
| raise HTTPException(status_code=503, detail="Model not loaded") | |
| encoding = self.tokenizer( | |
| preprocess_text(text), | |
| add_special_tokens=True, | |
| max_length=config.MAX_TEXT_LENGTH, | |
| padding='max_length', | |
| truncation=True, | |
| return_attention_mask=True, | |
| return_tensors='pt' | |
| ) | |
| input_ids = encoding['input_ids'].to(self.device) | |
| attention_mask = encoding['attention_mask'].to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model(input_ids, attention_mask) | |
| probs = torch.softmax(outputs, dim=1) | |
| predicted = torch.argmax(probs, dim=1) | |
| return { | |
| 'prediction': 'SPAM' if predicted.item() == 1 else 'LEGITIMATE', | |
| 'confidence': float(probs[0, predicted.item()].item() * 100), | |
| 'spam_probability': float(probs[0, 1].item() * 100), | |
| 'ham_probability': float(probs[0, 0].item() * 100) | |
| } | |
| # ================================ | |
| # FASTAPI APPLICATION | |
| # ================================ | |
| app = FastAPI( | |
| title="FYP4 Spam Detection API", | |
| description="Email spam detection using DeBERTa", | |
| version="1.0.0" | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| detector = None | |
| async def startup_event(): | |
| global detector | |
| print("π Starting API...") | |
| detector = SpamDetector() | |
| if detector.model: | |
| print("β API Ready!") | |
| else: | |
| print("β οΈ Model failed to load") | |
| class TextRequest(BaseModel): | |
| text: str | |
| async def root(): | |
| return { | |
| "message": "π― FYP4 Spam Detection API", | |
| "status": "running", | |
| "model": config.MODEL_REPO, | |
| "endpoints": { | |
| "POST /predict/text": "Detect spam from text", | |
| "POST /predict/pdf": "Detect spam from PDF", | |
| "GET /health": "Check API health" | |
| } | |
| } | |
| async def health(): | |
| return { | |
| "status": "healthy" if detector and detector.model else "unhealthy", | |
| "model_loaded": detector.model is not None if detector else False, | |
| "device": str(config.DEVICE) | |
| } | |
| async def predict_text(request: TextRequest): | |
| if not detector or not detector.model: | |
| raise HTTPException(status_code=503, detail="Model not available") | |
| result = detector.predict_text(request.text) | |
| return result | |
| async def predict_pdf(file: UploadFile = File(...)): | |
| if not file.filename.endswith('.pdf'): | |
| raise HTTPException(status_code=400, detail="Must be PDF file") | |
| if not detector or not detector.model: | |
| raise HTTPException(status_code=503, detail="Model not available") | |
| pdf_bytes = await file.read() | |
| email_data = PDFExtractor.extract_text_from_pdf(pdf_bytes) | |
| full_text = f"{email_data['subject']}\n\n{email_data['body']}" | |
| result = detector.predict_text(full_text) | |
| return { | |
| 'email_data': email_data, | |
| 'prediction': result | |
| } |