Spamforensics / app.py
haroon103's picture
Rename app.py.py to app.py
e202974 verified
"""
FYP4 SPAM DETECTION API
FastAPI application for email spam detection using DeBERTa and ViT models
"""
import os
import re
import io
import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import PyPDF2
import pdfplumber
from PIL import Image
from transformers import (
DebertaV2Model,
DebertaV2Tokenizer,
ViTModel,
ViTImageProcessor
)
# ================================
# CONFIGURATION
# ================================
class Config:
    """Static settings shared by the model definitions and the API layer."""

    # Prefer the GPU when torch reports one, otherwise run on the CPU.
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Checkpoint identifiers handed to the `from_pretrained` loaders.
    TEXT_MODEL = 'microsoft/deberta-v3-base'
    IMAGE_MODEL = 'google/vit-base-patch16-224-in21k'

    # Input widths of the two projection heads and the width they project to.
    TEXT_HIDDEN_DIM = 768
    IMAGE_HIDDEN_DIM = 768
    FUSION_DIM = 512

    # Output size of every classifier head.
    NUM_CLASSES = 2
    DROPOUT = 0.3

    # Token budget used when encoding text for the models.
    MAX_TEXT_LENGTH = 256
    IMG_SIZE = 224


# Single shared settings object referenced throughout the module.
config = Config()
# ================================
# MODEL ARCHITECTURES
# ================================
class DeBERTaTextEncoder(nn.Module):
    """Text branch: a pretrained DeBERTa backbone followed by a projection head.

    The backbone is loaded from ``config.TEXT_MODEL``; the head maps
    ``config.TEXT_HIDDEN_DIM`` features to ``config.FUSION_DIM`` features.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL)
        # The layer order is kept fixed: positions inside the Sequential become
        # part of the parameter names stored in a state dict.
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(config.TEXT_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return the projected hidden state taken at sequence position 0."""
        backbone_out = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state (presumably the [CLS] token).
        first_position = backbone_out.last_hidden_state[:, 0, :]
        return self.projection(first_position)
class ViTImageEncoder(nn.Module):
    """Image branch: a pretrained ViT backbone followed by a projection head.

    The backbone is loaded from ``config.IMAGE_MODEL``; the head maps
    ``config.IMAGE_HIDDEN_DIM`` features to ``config.FUSION_DIM`` features.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.vit = ViTModel.from_pretrained(config.IMAGE_MODEL)
        # The layer order is kept fixed: positions inside the Sequential become
        # part of the parameter names stored in a state dict.
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(config.IMAGE_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return the projected hidden state taken at sequence position 0."""
        backbone_out = self.vit(pixel_values=pixel_values, return_dict=True)
        # Position 0 of the last hidden state (presumably the [CLS] token).
        first_position = backbone_out.last_hidden_state[:, 0, :]
        return self.projection(first_position)
class CrossModalAttention(nn.Module):
    """Cross-attention block where text features query image features.

    The attention output is added back to the text features and normalised,
    then passed through a feed-forward network with a second residual
    connection and normalisation.
    """

    def __init__(self, dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        expanded = dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, dim),
            nn.Dropout(dropout),
        )

    def forward(self, text_features, image_features):
        """Fuse the two feature tensors and return one tensor per input row."""
        # Add a sequence axis of length 1 at position 1 for the attention call.
        query = text_features.unsqueeze(1)
        key_value = image_features.unsqueeze(1)
        # Text is the query; image serves as both key and value.
        attended, _weights = self.cross_attn(query, key_value, key_value)
        after_attention = self.norm1(query + attended)
        after_ffn = self.norm2(after_attention + self.ffn(after_attention))
        # Drop the sequence axis that was added above.
        return after_ffn.squeeze(1)
class TextSpamClassifier(nn.Module):
    """Text-only classifier: DeBERTa encoder plus a three-layer MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        # FUSION_DIM -> 256 -> 128 -> NUM_CLASSES. The layer order is kept
        # fixed because Sequential positions appear in state-dict keys.
        head_layers = [
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return the classifier output for the encoded text."""
        encoded = self.text_encoder(input_ids, attention_mask)
        return self.classifier(encoded)
class ImageSpamClassifier(nn.Module):
    """Image-only classifier: ViT encoder plus a three-layer MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.image_encoder = ViTImageEncoder(dropout)
        # FUSION_DIM -> 256 -> 128 -> NUM_CLASSES. The layer order is kept
        # fixed because Sequential positions appear in state-dict keys.
        head_layers = [
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return the classifier output for the encoded image."""
        encoded = self.image_encoder(pixel_values)
        return self.classifier(encoded)
class FusionSpamClassifier(nn.Module):
    """Classifier that accepts text, an image, or both.

    When both modalities are supplied their encodings are combined by
    ``CrossModalAttention``; with a single modality that encoder's output
    is fed to the classifier head directly.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        self.image_encoder = ViTImageEncoder(dropout)
        self.cross_modal_fusion = CrossModalAttention(
            config.FUSION_DIM, num_heads=8, dropout=dropout
        )
        # FUSION_DIM -> 256 -> 128 -> NUM_CLASSES. The layer order is kept
        # fixed because Sequential positions appear in state-dict keys.
        head_layers = [
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None):
        """Classify from whichever modalities were provided.

        Raises:
            ValueError: if neither ``input_ids`` nor ``pixel_values`` is given.
        """
        has_text = input_ids is not None
        has_image = pixel_values is not None
        if not (has_text or has_image):
            raise ValueError("At least one modality required")

        if has_text and has_image:
            text_part = self.text_encoder(input_ids, attention_mask)
            image_part = self.image_encoder(pixel_values)
            features = self.cross_modal_fusion(text_part, image_part)
        elif has_text:
            features = self.text_encoder(input_ids, attention_mask)
        else:
            features = self.image_encoder(pixel_values)
        return self.classifier(features)
# ================================
# PDF EXTRACTION
# ================================
class PDFExtractor:
    """Helpers that pull email text and an embedded image out of raw PDF bytes."""

    # Header labels searched for in the extracted text, keyed by output field.
    # Matching is case-insensitive, so one pattern per field is enough.
    _HEADER_PATTERNS = {
        'subject': re.compile(r'Subject:\s*(.+)', re.IGNORECASE),
        'sender': re.compile(r'From:\s*(.+)', re.IGNORECASE),
    }
    # Everything after the first blank line that follows a header label.
    _BODY_PATTERN = re.compile(
        r'(?:Subject|Date|From|To):.+?\n\n(.+)', re.DOTALL | re.IGNORECASE
    )

    @staticmethod
    def _join_page_texts(page_texts):
        """Concatenate the non-empty page texts, ending each with a newline."""
        return "".join(f"{text}\n" for text in page_texts if text)

    @staticmethod
    def _parse_email_fields(full_text):
        """Split extracted text into subject, sender, body and full_text fields."""
        email_data = {
            'subject': '',
            'sender': '',
            'body': '',
            'full_text': full_text
        }
        for field, pattern in PDFExtractor._HEADER_PATTERNS.items():
            match = pattern.search(full_text)
            if match:
                # Header values are capped at 100 characters.
                email_data[field] = match.group(1).strip()[:100]
        body_match = PDFExtractor._BODY_PATTERN.search(full_text)
        # Without a recognisable header block the whole text is the body.
        email_data['body'] = body_match.group(1).strip() if body_match else full_text
        return email_data

    @staticmethod
    def extract_text_from_pdf(pdf_bytes):
        """Extract text from PDF bytes

        pdfplumber is tried first; if it fails for any reason PyPDF2 is used
        instead. Whichever library succeeds, the text goes through the same
        header/body parsing so the result has the same shape and meaning.

        Args:
            pdf_bytes: raw content of the PDF file.

        Returns:
            dict with the keys 'subject', 'sender', 'body' and 'full_text'.

        Raises:
            HTTPException: status 400 when neither library can read the file.
        """
        try:
            with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
                full_text = PDFExtractor._join_page_texts(
                    page.extract_text() for page in pdf.pages
                )
        except Exception:
            # Best-effort fallback: a second parser may cope with files the
            # first one rejects.
            try:
                reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
                full_text = PDFExtractor._join_page_texts(
                    page.extract_text() for page in reader.pages
                )
            except Exception as exc:
                raise HTTPException(
                    status_code=400,
                    detail=f"Error extracting text from PDF: {str(exc)}"
                ) from exc
        return PDFExtractor._parse_email_fields(full_text)

    @staticmethod
    def _decode_image(x_object):
        """Turn one image XObject into a PIL image, raising if that fails."""
        size = (x_object['/Width'], x_object['/Height'])
        data = x_object.get_data()
        mode = "RGB" if x_object['/ColorSpace'] == '/DeviceRGB' else "P"
        try:
            return Image.frombytes(mode, size, data)
        except Exception:
            # frombytes only accepts raw pixel data. If the stream is still an
            # encoded image file (presumably the case for JPEG/DCTDecode
            # streams - verify against the installed PyPDF2 version), let
            # Pillow detect the format instead. The result is converted to
            # RGB, the mode used above for DeviceRGB streams.
            return Image.open(io.BytesIO(data)).convert("RGB")

    @staticmethod
    def extract_images_from_pdf(pdf_bytes):
        """Extract first image from PDF bytes

        Extraction is best-effort: unreadable files, pages without usable
        resources and images that cannot be decoded are skipped, not raised.

        Args:
            pdf_bytes: raw content of the PDF file.

        Returns:
            The first image that decodes successfully, or None.
        """
        try:
            reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
            pages = list(reader.pages)
        except Exception:
            return None

        for page in pages:
            try:
                resources = page['/Resources']
                if '/XObject' not in resources:
                    continue
                x_objects = resources['/XObject'].get_object()
            except Exception:
                # A page with missing or malformed resources must not stop
                # the search on the remaining pages.
                continue
            for name in x_objects:
                try:
                    candidate = x_objects[name]
                    if candidate['/Subtype'] != '/Image':
                        continue
                    return PDFExtractor._decode_image(candidate)
                except Exception:
                    # Undecodable entry: try the next one.
                    continue
        return None
# ================================
# TEXT PREPROCESSING
# ================================
def preprocess_text(text):
    """Normalise text for the models.

    The input is stringified and lower-cased, then URLs, e-mail addresses and
    digit runs are replaced by the placeholders ``[URL]``, ``[EMAIL]`` and
    ``[NUM]`` (in that order), and finally whitespace runs are collapsed to a
    single space with the ends trimmed.
    """
    # Order matters: URLs and addresses are masked before their digits are.
    substitutions = (
        (r'http\S+|www\.\S+', '[URL]'),
        (r'\S+@\S+', '[EMAIL]'),
        (r'\d+', '[NUM]'),
        (r'\s+', ' '),
    )
    cleaned = str(text).lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
# ================================
# SPAM DETECTOR
# ================================
class SpamDetector:
    """Loads the available spam classifiers and runs single-item inference.

    Each of the three models (text, image, fusion) is optional: a model is
    only loaded when a checkpoint path is given and the file exists, and the
    matching ``predict_*`` method returns None while its model is missing.
    """

    def __init__(self, text_model_path=None, image_model_path=None, fusion_model_path=None):
        """Create the tokenizer and image processor and load the checkpoints.

        Args:
            text_model_path: checkpoint file for ``TextSpamClassifier``.
            image_model_path: checkpoint file for ``ImageSpamClassifier``.
            fusion_model_path: checkpoint file for ``FusionSpamClassifier``.
        """
        self.device = config.DEVICE
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL)
        self.image_processor = ViTImageProcessor.from_pretrained(config.IMAGE_MODEL)
        # Load models
        self.text_model = self._load_model("text", TextSpamClassifier, text_model_path)
        self.image_model = self._load_model("image", ImageSpamClassifier, image_model_path)
        self.fusion_model = self._load_model("fusion", FusionSpamClassifier, fusion_model_path)

    def _load_model(self, label, model_cls, checkpoint_path):
        """Build ``model_cls`` and load its weights; None if there is no file."""
        if not checkpoint_path or not os.path.exists(checkpoint_path):
            return None
        print(f"Loading {label} model from {checkpoint_path}...")
        model = model_cls().to(self.device)
        # NOTE(security): torch.load unpickles the file, so only point these
        # paths at checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # The weights are stored under the 'model_state_dict' key.
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"{label.capitalize()} model loaded successfully!")
        return model

    def _encode_text(self, text):
        """Tokenise preprocessed text; returns (input_ids, attention_mask) on the device."""
        encoding = self.tokenizer(
            preprocess_text(text),
            add_special_tokens=True,
            max_length=config.MAX_TEXT_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        return input_ids, attention_mask

    @staticmethod
    def _format_prediction(outputs):
        """Convert the model output for one item into the result dictionary.

        Index 1 of the class axis is reported as SPAM, index 0 as LEGITIMATE;
        every probability is expressed as a percentage.
        """
        probs = torch.softmax(outputs, dim=1)
        predicted = int(torch.argmax(probs, dim=1).item())
        return {
            'prediction': 'SPAM' if predicted == 1 else 'LEGITIMATE',
            'confidence': float(probs[0, predicted].item() * 100),
            'spam_probability': float(probs[0, 1].item() * 100),
            'ham_probability': float(probs[0, 0].item() * 100)
        }

    def predict_text(self, text):
        """Classify text with the text model; None when it is not loaded."""
        if self.text_model is None:
            return None
        input_ids, attention_mask = self._encode_text(text)
        with torch.no_grad():
            outputs = self.text_model(input_ids, attention_mask)
        return self._format_prediction(outputs)

    def predict_image(self, image):
        """Classify an image with the image model.

        Returns:
            None when the model is not loaded or ``image`` is None; otherwise
            the result dictionary, or ``{'error': message}`` when
            preprocessing or inference raises.
        """
        if self.image_model is None or image is None:
            return None
        try:
            inputs = self.image_processor(images=image, return_tensors='pt')
            pixel_values = inputs['pixel_values'].to(self.device)
            with torch.no_grad():
                outputs = self.image_model(pixel_values)
            return self._format_prediction(outputs)
        except Exception as e:
            return {'error': str(e)}

    def predict_fusion(self, text, image=None):
        """Classify with the fusion model from text and, optionally, an image.

        If the image cannot be preprocessed the prediction falls back to the
        text alone. Returns None when the fusion model is not loaded.
        """
        if self.fusion_model is None:
            return None
        input_ids, attention_mask = self._encode_text(text)
        pixel_values = None
        if image is not None:
            try:
                image_inputs = self.image_processor(images=image, return_tensors='pt')
                pixel_values = image_inputs['pixel_values'].to(self.device)
            except Exception as exc:
                # Deliberate best-effort: an unusable image must not block the
                # text-only prediction, but the failure is made visible.
                pixel_values = None
                print(f"Image preprocessing failed, using text only: {exc}")
        with torch.no_grad():
            outputs = self.fusion_model(input_ids, attention_mask, pixel_values)
        return self._format_prediction(outputs)
# ================================
# FASTAPI APPLICATION
# ================================
# Application object together with the metadata passed to FastAPI.
_app_metadata = {
    "title": "FYP4 Spam Detection API",
    "description": "Email spam detection using DeBERTa and ViT models",
    "version": "1.0.0",
}
app = FastAPI(**_app_metadata)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive CORS setup - confirm this is intended for the deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Module-level detector handle; it stays None until the startup hook
# replaces it with a SpamDetector instance.
detector = None
@app.on_event("startup")
async def startup_event():
    """Create the global detector from the checkpoints present on disk.

    Checkpoint locations come from the TEXT_MODEL_PATH, IMAGE_MODEL_PATH and
    FUSION_MODEL_PATH environment variables, each with a default under
    ``models/``. Paths that do not exist are passed on as None.
    """
    global detector
    checkpoint_paths = {
        'text': os.getenv("TEXT_MODEL_PATH", "models/text_model.pth"),
        'image': os.getenv("IMAGE_MODEL_PATH", "models/image_model.pth"),
        'fusion': os.getenv("FUSION_MODEL_PATH", "models/fusion_model.pth"),
    }
    # Check which models exist
    available = {name: os.path.exists(path) for name, path in checkpoint_paths.items()}
    print(
        f"Models availability: Text={available['text']}, "
        f"Image={available['image']}, Fusion={available['fusion']}"
    )

    def usable(name):
        # Only hand over paths whose file is actually there.
        return checkpoint_paths[name] if available[name] else None

    detector = SpamDetector(
        text_model_path=usable('text'),
        image_model_path=usable('image'),
        fusion_model_path=usable('fusion'),
    )
    print("API ready!")
# Pydantic models for request/response
class TextRequest(BaseModel):
    # Request body of POST /predict/text.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays as it was.)

    # Text to classify; it is passed to the detector unchanged.
    text: str
class PredictionResponse(BaseModel):
    # Response body of POST /predict/text.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays as it was.)

    # 'SPAM' or 'LEGITIMATE'.
    prediction: str
    # Probability of the predicted class, as a percentage.
    confidence: float
    # Per-class probabilities, as percentages.
    spam_probability: float
    ham_probability: float
    # Name of the model that produced the result (the text endpoint sets 'text').
    model_used: str
class PDFPredictionResponse(BaseModel):
    # Response body of POST /predict/pdf.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays as it was.)

    # Fields extracted from the PDF: subject, sender, body and full_text.
    email_data: dict
    # Per-model results; None when that model is unavailable or was not run.
    # The explicit None default keeps these fields optional under pydantic v2,
    # where an Optional annotation without a default is a required field.
    text_result: Optional[dict] = None
    image_result: Optional[dict] = None
    fusion_result: Optional[dict] = None
    # Verdict and confidence copied from the highest-priority result.
    final_prediction: str
    final_confidence: float
@app.get("/")
async def root():
    """Root endpoint with API information"""
    # Route summary shown to clients, keyed by "<METHOD> <path>".
    endpoint_index = {
        "POST /predict/text": "Predict spam from text",
        "POST /predict/pdf": "Predict spam from PDF email",
        "GET /health": "Health check",
    }
    info = {"message": "FYP4 Spam Detection API", "version": "1.0.0"}
    info["endpoints"] = endpoint_index
    return info
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # A model counts as loaded only when a detector exists and the matching
    # attribute on it is not None.
    models_loaded = {}
    for name in ("text", "image", "fusion"):
        model = getattr(detector, f"{name}_model") if detector else None
        models_loaded[name] = model is not None
    return {
        "status": "healthy",
        "device": str(config.DEVICE),
        "models_loaded": models_loaded,
    }
@app.post("/predict/text", response_model=PredictionResponse)
async def predict_text(request: TextRequest):
    """Predict spam from text content"""
    # Answer 503 while the detector or its text model is missing.
    if not (detector and detector.text_model):
        raise HTTPException(status_code=503, detail="Text model not available")
    outcome = detector.predict_text(request.text)
    # Tag the result with the model that produced it.
    outcome.update(model_used='text')
    return outcome
@app.post("/predict/pdf", response_model=PDFPredictionResponse)
async def predict_pdf(file: UploadFile = File(...)):
    """Predict spam from PDF email"""
    # The docstring above is left unchanged because FastAPI publishes it as
    # the endpoint description.
    #
    # Flow: validate the upload, extract text and the first image, run every
    # loaded model, then report the highest-priority usable result
    # (fusion > text > image).
    # Raises HTTPException 400 for a non-PDF upload (or unreadable PDF) and
    # 503 when no detector or no usable prediction is available.

    # An upload may arrive without a filename, and the extension may be
    # written in upper case.
    filename = file.filename or ""
    if not filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a PDF")
    if not detector:
        raise HTTPException(status_code=503, detail="Models not loaded")
    # Read PDF
    pdf_bytes = await file.read()
    # Extract text and images
    email_data = PDFExtractor.extract_text_from_pdf(pdf_bytes)
    full_text = f"{email_data['subject']}\n\n{email_data['body']}"
    image = PDFExtractor.extract_images_from_pdf(pdf_bytes)
    # Get predictions
    results = {
        'email_data': email_data,
        'text_result': None,
        'image_result': None,
        'fusion_result': None
    }
    if detector.text_model is not None:
        results['text_result'] = detector.predict_text(full_text)
    if detector.image_model is not None and image is not None:
        results['image_result'] = detector.predict_image(image)
    if detector.fusion_model is not None:
        results['fusion_result'] = detector.predict_fusion(full_text, image)
    # Determine final prediction (prioritize: fusion > text > image).
    # A result without a 'prediction' key (the image model reports failures
    # as {'error': ...}) is skipped so it can never become the final verdict.
    candidates = (
        results['fusion_result'],
        results['text_result'],
        results['image_result'],
    )
    final_result = next(
        (candidate for candidate in candidates if candidate and 'prediction' in candidate),
        None,
    )
    if final_result is None:
        raise HTTPException(status_code=503, detail="No models available for prediction")
    results['final_prediction'] = final_result['prediction']
    results['final_confidence'] = final_result['confidence']
    return results
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on every interface.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)