""" FYP4 SPAM DETECTION API - Loads model from Hugging Face Hub """ import os import re import io import torch import torch.nn as nn from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from huggingface_hub import hf_hub_download from safetensors.torch import load_file import PyPDF2 import pdfplumber from PIL import Image from transformers import DebertaV2Model, DebertaV2Tokenizer # ================================ # CONFIGURATION # ================================ class Config: DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') TEXT_MODEL = 'microsoft/deberta-v3-base' # 🔥 CHANGE THIS TO YOUR MODEL REPO! 🔥 MODEL_REPO = "haroon103/fyp4-spam-model" # ← Change haroon103 to YOUR username TEXT_HIDDEN_DIM = 768 FUSION_DIM = 512 NUM_CLASSES = 2 DROPOUT = 0.3 MAX_TEXT_LENGTH = 256 config = Config() # ================================ # MODEL ARCHITECTURES # ================================ class DeBERTaTextEncoder(nn.Module): def __init__(self, dropout=0.3): super(DeBERTaTextEncoder, self).__init__() self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL) self.projection = nn.Sequential( nn.Dropout(dropout), nn.Linear(config.TEXT_HIDDEN_DIM, config.FUSION_DIM), nn.LayerNorm(config.FUSION_DIM), nn.GELU() ) def forward(self, input_ids, attention_mask): outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask) pooled = outputs.last_hidden_state[:, 0, :] return self.projection(pooled) class TextSpamClassifier(nn.Module): def __init__(self, dropout=0.3): super(TextSpamClassifier, self).__init__() self.text_encoder = DeBERTaTextEncoder(dropout) self.classifier = nn.Sequential( nn.Linear(config.FUSION_DIM, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout), nn.Linear(256, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(dropout), nn.Linear(128, config.NUM_CLASSES) ) def forward(self, input_ids, attention_mask): features = self.text_encoder(input_ids, attention_mask) return self.classifier(features) # ================================ # PDF EXTRACTION # ================================ class PDFExtractor: @staticmethod def extract_text_from_pdf(pdf_bytes): email_data = {'subject': '', 'sender': '', 'body': '', 'full_text': ''} try: pdf_file = io.BytesIO(pdf_bytes) with pdfplumber.open(pdf_file) as pdf: full_text = "" for page in pdf.pages: text = page.extract_text() if text: full_text += text + "\n" email_data['full_text'] = full_text patterns = { 'subject': [r'Subject:\s*(.+)', r'SUBJECT:\s*(.+)'], 'sender': [r'From:\s*(.+)', r'FROM:\s*(.+)'] } for field, pattern_list in patterns.items(): for pattern in pattern_list: match = re.search(pattern, full_text, re.IGNORECASE) if match: email_data[field] = match.group(1).strip()[:100] break body_match = re.search(r'(?:Subject|Date|From|To):.+?\n\n(.+)', full_text, re.DOTALL | re.IGNORECASE) if body_match: email_data['body'] = body_match.group(1).strip() else: email_data['body'] = full_text return email_data except: try: pdf_file = io.BytesIO(pdf_bytes) pdf_reader = PyPDF2.PdfReader(pdf_file) full_text = "" for page in pdf_reader.pages: text = page.extract_text() if text: full_text += text + "\n" email_data['full_text'] = full_text email_data['body'] = full_text return email_data except Exception as e: raise HTTPException(status_code=400, detail=f"Error: {str(e)}") # ================================ # TEXT PREPROCESSING # ================================ def preprocess_text(text): text = str(text).lower() text = re.sub(r'http\S+|www\.\S+', '[URL]', text) text = re.sub(r'\S+@\S+', '[EMAIL]', text) text = re.sub(r'\d+', '[NUM]', text) text = re.sub(r'\s+', ' ', text).strip() return text # ================================ # SPAM DETECTOR # ================================ class SpamDetector: def __init__(self): self.device = config.DEVICE self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL) self.model = None try: print(f"📥 Downloading model from {config.MODEL_REPO}...") model_path = hf_hub_download( repo_id=config.MODEL_REPO, filename="model.safetensors" ) print(f"✅ Downloaded to: {model_path}") print("🔧 Loading model...") self.model = TextSpamClassifier(dropout=config.DROPOUT).to(self.device) state_dict = load_file(model_path) self.model.load_state_dict(state_dict, strict=False) self.model.eval() print("✅ Model ready!") except Exception as e: print(f"❌ Error: {e}") self.model = None def predict_text(self, text): if not self.model: raise HTTPException(status_code=503, detail="Model not loaded") encoding = self.tokenizer( preprocess_text(text), add_special_tokens=True, max_length=config.MAX_TEXT_LENGTH, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt' ) input_ids = encoding['input_ids'].to(self.device) attention_mask = encoding['attention_mask'].to(self.device) with torch.no_grad(): outputs = self.model(input_ids, attention_mask) probs = torch.softmax(outputs, dim=1) predicted = torch.argmax(probs, dim=1) return { 'prediction': 'SPAM' if predicted.item() == 1 else 'LEGITIMATE', 'confidence': float(probs[0, predicted.item()].item() * 100), 'spam_probability': float(probs[0, 1].item() * 100), 'ham_probability': float(probs[0, 0].item() * 100) } # ================================ # FASTAPI APPLICATION # ================================ app = FastAPI( title="FYP4 Spam Detection API", description="Email spam detection using DeBERTa", version="1.0.0" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) detector = None @app.on_event("startup") async def startup_event(): global detector print("🚀 Starting API...") detector = SpamDetector() if detector.model: print("✅ API Ready!") else: print("⚠️ Model failed to load") class TextRequest(BaseModel): text: str @app.get("/") async def root(): return { "message": "🎯 FYP4 Spam Detection API", "status": "running", "model": config.MODEL_REPO, "endpoints": { "POST /predict/text": "Detect spam from text", "POST /predict/pdf": "Detect spam from PDF", "GET /health": "Check API health" } } @app.get("/health") async def health(): return { "status": "healthy" if detector and detector.model else "unhealthy", "model_loaded": detector.model is not None if detector else False, "device": str(config.DEVICE) } @app.post("/predict/text") async def predict_text(request: TextRequest): if not detector or not detector.model: raise HTTPException(status_code=503, detail="Model not available") result = detector.predict_text(request.text) return result @app.post("/predict/pdf") async def predict_pdf(file: UploadFile = File(...)): if not file.filename.endswith('.pdf'): raise HTTPException(status_code=400, detail="Must be PDF file") if not detector or not detector.model: raise HTTPException(status_code=503, detail="Model not available") pdf_bytes = await file.read() email_data = PDFExtractor.extract_text_from_pdf(pdf_bytes) full_text = f"{email_data['subject']}\n\n{email_data['body']}" result = detector.predict_text(full_text) return { 'email_data': email_data, 'prediction': result }