# Source: Hugging Face Space "sppamemail", file app.py
# Last commit by haroon103: "Update app.py" (1085ade, verified)
"""
FYP4 SPAM DETECTION API - Loads model from Hugging Face Hub
"""
import io
import os
import re

import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import PyPDF2
import pdfplumber
from PIL import Image
from transformers import DebertaV2Model, DebertaV2Tokenizer
# ================================
# CONFIGURATION
# ================================
class Config:
    """Static settings shared by the model, the tokenizer and the API.

    MODEL_REPO can be overridden with the ``MODEL_REPO`` environment variable,
    so a fork or deployment can point at its own Hub repo without editing code.
    """

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Backbone name used for both DebertaV2Model and DebertaV2Tokenizer.
    TEXT_MODEL = 'microsoft/deberta-v3-base'
    # Hub repo that holds the fine-tuned "model.safetensors". An unset or empty
    # MODEL_REPO environment variable falls back to the original author's repo.
    MODEL_REPO = os.environ.get("MODEL_REPO") or "haroon103/fyp4-spam-model"
    # Input width of the projection layer; must match the backbone's hidden size.
    TEXT_HIDDEN_DIM = 768
    # Width of the projected text features fed to the classifier head.
    FUSION_DIM = 512
    # Class index 0 = legitimate (ham), 1 = spam (see SpamDetector.predict_text).
    NUM_CLASSES = 2
    DROPOUT = 0.3
    # Inputs are padded/truncated to this many tokens.
    MAX_TEXT_LENGTH = 256


config = Config()
# ================================
# MODEL ARCHITECTURES
# ================================
class DeBERTaTextEncoder(nn.Module):
    """DeBERTa backbone followed by a projection of the first-token state.

    Args:
        dropout: Dropout probability applied before the linear projection.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        # The attribute names and layer order below define the state_dict keys
        # (e.g. "projection.1.weight"); keep them stable so checkpoints load.
        self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL)
        projection_layers = [
            nn.Dropout(dropout),
            nn.Linear(config.TEXT_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        ]
        self.projection = nn.Sequential(*projection_layers)

    def forward(self, input_ids, attention_mask):
        """Encode token ids and return projected features (last dim FUSION_DIM)."""
        hidden_states = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Take sequence position 0 for every item (presumably the [CLS]
        # token, since the tokenizer is called with add_special_tokens=True).
        first_token = hidden_states[:, 0]
        return self.projection(first_token)
class TextSpamClassifier(nn.Module):
    """Text-only spam classifier: DeBERTaTextEncoder followed by an MLP head.

    Args:
        dropout: Dropout probability used in the encoder projection and head.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        # Head layout; the indices are part of the checkpoint keys
        # ("classifier.<i>.weight"), so this order must not change:
        #   0 Linear(FUSION_DIM->256) 1 LayerNorm 2 GELU 3 Dropout
        #   4 Linear(256->128)        5 LayerNorm 6 GELU 7 Dropout
        #   8 Linear(128->NUM_CLASSES)
        head_layers = []
        in_features = config.FUSION_DIM
        for width in (256, 128):
            head_layers.extend([
                nn.Linear(in_features, width),
                nn.LayerNorm(width),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            in_features = width
        head_layers.append(nn.Linear(in_features, config.NUM_CLASSES))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits; SpamDetector applies softmax."""
        return self.classifier(self.text_encoder(input_ids, attention_mask))
# ================================
# PDF EXTRACTION
# ================================
class PDFExtractor:
    """Pull email-style fields (subject, sender, body) out of an uploaded PDF.

    pdfplumber is tried first. PyPDF2 is used as a fallback when pdfplumber
    raises, or when it opens the file but extracts no text.
    """

    # re.IGNORECASE (applied below) already covers "SUBJECT:" / "FROM:", so a
    # single pattern per field is sufficient.
    _FIELD_PATTERNS = {
        'subject': r'Subject:\s*(.+)',
        'sender': r'From:\s*(.+)',
    }
    # Body = everything after the first blank line that follows a header line.
    _BODY_PATTERN = r'(?:Subject|Date|From|To):.+?\n\n(.+)'
    # Extracted header values are truncated to this many characters.
    _MAX_FIELD_LENGTH = 100

    @staticmethod
    def _join_page_text(pages):
        """Concatenate page.extract_text() of every page, each newline-terminated."""
        chunks = []
        for page in pages:
            text = page.extract_text()
            if text:
                chunks.append(text + "\n")
        return "".join(chunks)

    @staticmethod
    def _parse_email_fields(full_text):
        """Build the email_data dict by regex-matching header lines in full_text."""
        email_data = {'subject': '', 'sender': '', 'body': '', 'full_text': full_text}
        for field, pattern in PDFExtractor._FIELD_PATTERNS.items():
            match = re.search(pattern, full_text, re.IGNORECASE)
            if match:
                email_data[field] = match.group(1).strip()[:PDFExtractor._MAX_FIELD_LENGTH]
        body_match = re.search(
            PDFExtractor._BODY_PATTERN, full_text, re.DOTALL | re.IGNORECASE
        )
        # Without a recognisable header block, treat the whole text as the body.
        email_data['body'] = body_match.group(1).strip() if body_match else full_text
        return email_data

    @staticmethod
    def extract_text_from_pdf(pdf_bytes):
        """Extract text and email fields from raw PDF bytes.

        Args:
            pdf_bytes: Raw contents of the uploaded PDF file.

        Returns:
            dict with keys 'subject', 'sender', 'body' and 'full_text'. When the
            PyPDF2 fallback is used, only 'body' and 'full_text' are filled.

        Raises:
            HTTPException: 400 if pdfplumber raised and PyPDF2 also fails.
        """
        plumber_error = None
        full_text = ""
        try:
            with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
                full_text = PDFExtractor._join_page_text(pdf.pages)
        except Exception as exc:  # deliberate fallback boundary, not a bare except
            plumber_error = exc

        if plumber_error is None and full_text.strip():
            return PDFExtractor._parse_email_fields(full_text)

        # pdfplumber failed or found no text: try PyPDF2.
        try:
            reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
            fallback_text = PDFExtractor._join_page_text(reader.pages)
        except Exception as exc:
            if plumber_error is None:
                # pdfplumber did open the file (it just had no text); keep
                # that result instead of turning it into an error.
                return PDFExtractor._parse_email_fields(full_text)
            raise HTTPException(status_code=400, detail=f"Error: {str(exc)}") from exc

        return {'subject': '', 'sender': '', 'body': fallback_text, 'full_text': fallback_text}
# ================================
# TEXT PREPROCESSING
# ================================
def preprocess_text(text):
    """Normalise raw text before tokenisation.

    Lower-cases the input and replaces URLs, e-mail addresses and digit runs
    with the placeholder tokens ``[URL]``, ``[EMAIL]`` and ``[NUM]``. It then
    collapses whitespace runs to single spaces and trims both ends.
    NOTE(review): presumably this must mirror the preprocessing used at
    training time — confirm before changing it.
    """
    # Order matters: URLs are replaced first, so digits or '@' inside them do
    # not turn into [NUM] / [EMAIL] fragments.
    substitutions = (
        (r'http\S+|www\.\S+', '[URL]'),
        (r'\S+@\S+', '[EMAIL]'),
        (r'\d+', '[NUM]'),
        (r'\s+', ' '),
    )
    cleaned = str(text).lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
# ================================
# SPAM DETECTOR
# ================================
class SpamDetector:
    """Download the fine-tuned classifier from the Hub and run text inference.

    Model download/load failures do not propagate out of the constructor. The
    error is printed and ``self.model`` stays ``None``, which predict_text then
    reports as HTTP 503. A tokenizer load failure does propagate.
    """

    def __init__(self):
        self.device = config.DEVICE
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL)
        self.model = None
        try:
            print(f"πŸ“₯ Downloading model from {config.MODEL_REPO}...")
            model_path = hf_hub_download(
                repo_id=config.MODEL_REPO,
                filename="model.safetensors",
            )
            print(f"βœ… Downloaded to: {model_path}")
            print("πŸ”§ Loading model...")
            model = TextSpamClassifier(dropout=config.DROPOUT).to(self.device)
            state_dict = load_file(model_path)
            # strict=False tolerates key mismatches. A mismatch usually means
            # some layers keep their pre-load weights, so report it instead of
            # silently serving a partially loaded model.
            load_result = model.load_state_dict(state_dict, strict=False)
            if load_result.missing_keys:
                print(
                    f"WARNING: {len(load_result.missing_keys)} model weights were not in the "
                    f"checkpoint and keep their pre-load values, e.g. {load_result.missing_keys[:5]}"
                )
            if load_result.unexpected_keys:
                print(
                    f"WARNING: {len(load_result.unexpected_keys)} checkpoint tensors were ignored, "
                    f"e.g. {load_result.unexpected_keys[:5]}"
                )
            model.eval()
            # Publish the model only once it is fully loaded and in eval mode.
            self.model = model
            print("βœ… Model ready!")
        except Exception as e:
            print(f"❌ Error: {e}")
            self.model = None

    def predict_text(self, text):
        """Classify one text as SPAM or LEGITIMATE.

        Args:
            text: Raw input; it is normalised with preprocess_text first.

        Returns:
            dict with 'prediction' ('SPAM' / 'LEGITIMATE') and the 'confidence',
            'spam_probability' and 'ham_probability' values as percentages.

        Raises:
            HTTPException: 503 if no model is loaded.
        """
        if self.model is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        encoding = self.tokenizer(
            preprocess_text(text),
            add_special_tokens=True,
            max_length=config.MAX_TEXT_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
        # Single-item batch: class index 1 is spam, 0 is legitimate (ham).
        label_idx = int(torch.argmax(probs, dim=1).item())
        return {
            'prediction': 'SPAM' if label_idx == 1 else 'LEGITIMATE',
            'confidence': float(probs[0, label_idx].item() * 100),
            'spam_probability': float(probs[0, 1].item() * 100),
            'ham_probability': float(probs[0, 0].item() * 100),
        }
# ================================
# FASTAPI APPLICATION
# ================================
# ASGI application instance served by the web server.
app = FastAPI(
    title="FYP4 Spam Detection API",
    description="Email spam detection using DeBERTa",
    version="1.0.0",
)

# Fully permissive CORS: any origin, method and header.
# NOTE(review): Starlette's docs say origins/methods/headers must be listed
# explicitly (not "*") for credentialed requests. Nothing in this file uses
# cookies or auth headers, so allow_credentials=True is likely unnecessary — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Set by startup_event(); stays None until the startup hook has run.
detector = None
@app.on_event("startup")
async def startup_event():
    """Create the global SpamDetector when the server starts.

    SpamDetector catches model download/load errors itself, so this hook only
    reports whether a model ended up loaded.
    """
    global detector
    print("πŸš€ Starting API...")
    detector = SpamDetector()
    outcome = "βœ… API Ready!" if detector.model is not None else "⚠️ Model failed to load"
    print(outcome)
class TextRequest(BaseModel):
    """JSON body accepted by POST /predict/text."""

    # Raw message text; it is normalised inside SpamDetector.predict_text.
    text: str
@app.get("/")
async def root():
    """Describe the service: status, model repo and the available endpoints."""
    endpoint_summary = {
        "POST /predict/text": "Detect spam from text",
        "POST /predict/pdf": "Detect spam from PDF",
        "GET /health": "Check API health",
    }
    return {
        "message": "🎯 FYP4 Spam Detection API",
        "status": "running",
        "model": config.MODEL_REPO,
        "endpoints": endpoint_summary,
    }
@app.get("/health")
async def health():
    """Report whether the detector and its model are loaded, plus the device."""
    model_loaded = detector is not None and detector.model is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "device": str(config.DEVICE),
    }
@app.post("/predict/text")
async def predict_text(request: TextRequest):
    """Classify the submitted text as SPAM or LEGITIMATE.

    Model inference is blocking (tokenisation plus a forward pass), so it runs
    in the worker threadpool. Running it directly in this coroutine would stall
    the event loop, and with it every other request, including /health.

    Raises:
        HTTPException: 503 if the detector or its model is not available.
    """
    if detector is None or detector.model is None:
        raise HTTPException(status_code=503, detail="Model not available")
    return await run_in_threadpool(detector.predict_text, request.text)
@app.post("/predict/pdf")
async def predict_pdf(file: UploadFile = File(...)):
    """Extract email text from an uploaded PDF and classify it.

    The extension check is case-insensitive (".pdf", ".PDF", ...) and tolerates
    a missing filename. PDF parsing and model inference are blocking, so both
    run in the worker threadpool instead of on the event loop.

    Returns:
        dict with the extracted 'email_data' and the classifier 'prediction'.

    Raises:
        HTTPException: 400 for a non-PDF filename or unreadable PDF; 503 if the
            model is not available.
    """
    # UploadFile.filename can be None; treat that as "not a PDF".
    filename = file.filename or ""
    if not filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Must be PDF file")
    if detector is None or detector.model is None:
        raise HTTPException(status_code=503, detail="Model not available")
    pdf_bytes = await file.read()
    email_data = await run_in_threadpool(PDFExtractor.extract_text_from_pdf, pdf_bytes)
    full_text = f"{email_data['subject']}\n\n{email_data['body']}"
    result = await run_in_threadpool(detector.predict_text, full_text)
    return {
        'email_data': email_data,
        'prediction': result,
    }