"""
FYP4 SPAM DETECTION API
FastAPI application for email spam detection using DeBERTa and ViT models
"""

import io
import os
import re
from typing import Optional

import pdfplumber
import PyPDF2
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse  # noqa: F401  (kept; not used below)
from PIL import Image
from pydantic import BaseModel
from transformers import (
    DebertaV2Model,
    DebertaV2Tokenizer,
    ViTImageProcessor,
    ViTModel,
)


# ================================
# CONFIGURATION
# ================================
class Config:
    """Static hyper-parameters shared by the model definitions and the detector."""

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    TEXT_MODEL = 'microsoft/deberta-v3-base'
    IMAGE_MODEL = 'google/vit-base-patch16-224-in21k'
    TEXT_HIDDEN_DIM = 768
    IMAGE_HIDDEN_DIM = 768
    FUSION_DIM = 512
    NUM_CLASSES = 2
    DROPOUT = 0.3
    MAX_TEXT_LENGTH = 256
    IMG_SIZE = 224


config = Config()


# ================================
# MODEL ARCHITECTURES
# ================================
# NOTE: layer order and attribute names below define the state-dict keys that the
# saved checkpoints are loaded against (e.g. "classifier.4.weight"); do not reorder.

def _projection_head(in_dim, dropout):
    """Build Dropout -> Linear(in_dim, FUSION_DIM) -> LayerNorm -> GELU."""
    return nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(in_dim, config.FUSION_DIM),
        nn.LayerNorm(config.FUSION_DIM),
        nn.GELU(),
    )


def _classifier_head(dropout):
    """Build the shared MLP head: FUSION_DIM -> 256 -> 128 -> NUM_CLASSES logits."""
    return nn.Sequential(
        nn.Linear(config.FUSION_DIM, 256),
        nn.LayerNorm(256),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(256, 128),
        nn.LayerNorm(128),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(128, config.NUM_CLASSES),
    )


class DeBERTaTextEncoder(nn.Module):
    """Encode token ids with DeBERTa and project the first-token state to FUSION_DIM."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL)
        self.projection = _projection_head(config.TEXT_HIDDEN_DIM, dropout)

    def forward(self, input_ids, attention_mask):
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # The hidden state at position 0 is used as the sequence summary.
        pooled = outputs.last_hidden_state[:, 0, :]
        return self.projection(pooled)


class ViTImageEncoder(nn.Module):
    """Encode pixel values with ViT and project the first-token state to FUSION_DIM."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.vit = ViTModel.from_pretrained(config.IMAGE_MODEL)
        self.projection = _projection_head(config.IMAGE_HIDDEN_DIM, dropout)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values, return_dict=True)
        pooled = outputs.last_hidden_state[:, 0, :]
        return self.projection(pooled)


class CrossModalAttention(nn.Module):
    """Fuse text and image feature vectors with text-queries-image attention + FFN.

    Both inputs are treated as length-1 sequences, so the attention has a single
    key per sample; residual connections and LayerNorm follow each sub-layer.
    """

    def __init__(self, dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )

    def forward(self, text_features, image_features):
        # (batch, dim) -> (batch, 1, dim) so nn.MultiheadAttention sees a sequence.
        text_features = text_features.unsqueeze(1)
        image_features = image_features.unsqueeze(1)
        attn_output, _ = self.cross_attn(text_features, image_features, image_features)
        fused = self.norm1(text_features + attn_output)
        output = self.norm2(fused + self.ffn(fused))
        return output.squeeze(1)


class TextSpamClassifier(nn.Module):
    """Text-only spam classifier: DeBERTa encoder + MLP head, returns 2-class logits."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        self.classifier = _classifier_head(dropout)

    def forward(self, input_ids, attention_mask):
        features = self.text_encoder(input_ids, attention_mask)
        return self.classifier(features)


class ImageSpamClassifier(nn.Module):
    """Image-only spam classifier: ViT encoder + MLP head, returns 2-class logits."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.image_encoder = ViTImageEncoder(dropout)
        self.classifier = _classifier_head(dropout)

    def forward(self, pixel_values):
        features = self.image_encoder(pixel_values)
        return self.classifier(features)


class FusionSpamClassifier(nn.Module):
    """Multimodal classifier that also accepts either modality alone.

    With both inputs, text features attend to image features before the head;
    with one input, that encoder's projected features go straight to the head.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        self.image_encoder = ViTImageEncoder(dropout)
        self.cross_modal_fusion = CrossModalAttention(config.FUSION_DIM, num_heads=8, dropout=dropout)
        self.classifier = _classifier_head(dropout)

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None):
        if input_ids is not None and pixel_values is not None:
            text_features = self.text_encoder(input_ids, attention_mask)
            image_features = self.image_encoder(pixel_values)
            fused_features = self.cross_modal_fusion(text_features, image_features)
        elif input_ids is not None:
            fused_features = self.text_encoder(input_ids, attention_mask)
        elif pixel_values is not None:
            fused_features = self.image_encoder(pixel_values)
        else:
            raise ValueError("At least one modality required")
        return self.classifier(fused_features)


# ================================
# PDF EXTRACTION
# ================================
class PDFExtractor:
    """Static helpers that pull email text and an embedded image out of PDF bytes."""

    # re.IGNORECASE already covers upper-case variants such as "SUBJECT:" / "FROM:".
    _HEADER_PATTERNS = {
        'subject': re.compile(r'Subject:\s*(.+)', re.IGNORECASE),
        'sender': re.compile(r'From:\s*(.+)', re.IGNORECASE),
    }
    # Body = text after the first blank line that follows the first header keyword.
    _BODY_PATTERN = re.compile(r'(?:Subject|Date|From|To):.+?\n\n(.+)', re.DOTALL | re.IGNORECASE)
    # Filters whose stream data is an encoded image file (JPEG / JPEG 2000) for PIL to open.
    _ENCODED_IMAGE_FILTERS = {'/DCTDecode', '/JPXDecode'}
    # PIL modes for raw pixel data in device color spaces.
    _RAW_COLOR_MODES = {'/DeviceRGB': 'RGB', '/DeviceGray': 'L', '/DeviceCMYK': 'CMYK'}
    # PIL modes for /ICCBased color spaces, keyed by component count (/N).
    _ICC_COMPONENT_MODES = {1: 'L', 3: 'RGB', 4: 'CMYK'}

    @staticmethod
    def _join_page_texts(pages):
        """Concatenate each page's extract_text(), one trailing newline per non-empty page."""
        page_texts = (page.extract_text() for page in pages)
        return "".join(text + "\n" for text in page_texts if text)

    @staticmethod
    def extract_text_from_pdf(pdf_bytes):
        """Extract text from PDF bytes.

        Tries pdfplumber first and parses Subject/From headers plus the body.
        If pdfplumber fails, falls back to PyPDF2 and uses the full text as body.

        Returns:
            dict with keys 'subject', 'sender', 'body', 'full_text' (all str).

        Raises:
            HTTPException(400): neither library could read the PDF.
        """
        email_data = {'subject': '', 'sender': '', 'body': '', 'full_text': ''}
        try:
            with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
                full_text = PDFExtractor._join_page_texts(pdf.pages)
        except Exception:
            # pdfplumber could not parse the file: PyPDF2 fallback, no header parsing.
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
                full_text = PDFExtractor._join_page_texts(pdf_reader.pages)
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail=f"Error extracting text from PDF: {str(e)}"
                ) from e
            email_data['full_text'] = full_text
            email_data['body'] = full_text
            return email_data

        email_data['full_text'] = full_text
        for field, pattern in PDFExtractor._HEADER_PATTERNS.items():
            match = pattern.search(full_text)
            if match:
                email_data[field] = match.group(1).strip()[:100]

        body_match = PDFExtractor._BODY_PATTERN.search(full_text)
        email_data['body'] = body_match.group(1).strip() if body_match else full_text
        return email_data

    @staticmethod
    def _pil_mode_for(color_space):
        """Map a PDF /ColorSpace value to a PIL mode, or None if unsupported."""
        if isinstance(color_space, str):
            return PDFExtractor._RAW_COLOR_MODES.get(color_space)
        if isinstance(color_space, list) and color_space and color_space[0] == '/ICCBased':
            components = color_space[1].get_object().get('/N')
            return PDFExtractor._ICC_COMPONENT_MODES.get(components)
        # Indexed, Separation, DeviceN, ... would need palette/tint decoding.
        return None

    @staticmethod
    def _xobject_to_image(x_object):
        """Decode one image XObject into a PIL image, or return None if unsupported."""
        data = x_object.get_data()
        raw_filter = x_object['/Filter'] if '/Filter' in x_object else []
        filters = raw_filter if isinstance(raw_filter, list) else [raw_filter]
        if any(f in PDFExtractor._ENCODED_IMAGE_FILTERS for f in filters):
            # Assumes get_data() leaves JPEG/JPEG 2000 streams encoded (PIL decodes them).
            return Image.open(io.BytesIO(data))

        color_space = x_object['/ColorSpace'] if '/ColorSpace' in x_object else None
        bits = x_object['/BitsPerComponent'] if '/BitsPerComponent' in x_object else 8
        mode = PDFExtractor._pil_mode_for(color_space)
        if mode is None or bits != 8:
            return None
        size = (int(x_object['/Width']), int(x_object['/Height']))
        return Image.frombytes(mode, size, data)

    @staticmethod
    def _first_image_on_page(page):
        """Return the first decodable image XObject on one page as RGB, or None."""
        try:
            if '/Resources' not in page:
                return None
            resources = page['/Resources']
            if '/XObject' not in resources:
                return None
            x_objects = resources['/XObject'].get_object()
        except Exception:
            return None

        for name in x_objects:
            try:
                x_object = x_objects[name]
                if x_object['/Subtype'] != '/Image':
                    continue
                image = PDFExtractor._xobject_to_image(x_object)
                if image is not None:
                    # convert() forces decoding and gives the 3-channel input ViT expects.
                    return image.convert('RGB')
            except Exception:
                # One undecodable image must not stop the search.
                continue
        return None

    @staticmethod
    def extract_images_from_pdf(pdf_bytes):
        """Extract first image from PDF bytes.

        Returns:
            The first decodable embedded image as an RGB PIL image, or None.
            This is best-effort: unreadable PDFs also yield None.
        """
        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
            for page in pdf_reader.pages:
                image = PDFExtractor._first_image_on_page(page)
                if image is not None:
                    return image
        except Exception:
            # Image analysis is optional; text analysis still runs without it.
            pass
        return None


# ================================
# TEXT PREPROCESSING
# ================================
_URL_RE = re.compile(r'http\S+|www\.\S+')
_EMAIL_RE = re.compile(r'\S+@\S+')
_NUM_RE = re.compile(r'\d+')
_WHITESPACE_RE = re.compile(r'\s+')


def preprocess_text(text):
    """Preprocess text for model input.

    Lower-cases, then replaces URLs, e-mail addresses and digit runs with
    [URL] / [EMAIL] / [NUM] and collapses whitespace (in that order).
    """
    text = str(text).lower()
    text = _URL_RE.sub('[URL]', text)
    text = _EMAIL_RE.sub('[EMAIL]', text)
    text = _NUM_RE.sub('[NUM]', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


# ================================
# SPAM DETECTOR
# ================================
class SpamDetector:
    """Load the optional text/image/fusion classifiers and run single-sample inference.

    Each predict_* method returns None when its model is not loaded, otherwise
    a dict with 'prediction' ('SPAM'/'LEGITIMATE'), 'confidence',
    'spam_probability' and 'ham_probability' (percentages).
    """

    def __init__(self, text_model_path=None, image_model_path=None, fusion_model_path=None):
        self.device = config.DEVICE
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL)
        self.image_processor = ViTImageProcessor.from_pretrained(config.IMAGE_MODEL)
        # A falsy or non-existent path leaves the corresponding model as None.
        self.text_model = self._load_model(text_model_path, TextSpamClassifier, "text")
        self.image_model = self._load_model(image_model_path, ImageSpamClassifier, "image")
        self.fusion_model = self._load_model(fusion_model_path, FusionSpamClassifier, "fusion")

    def _load_model(self, model_path, model_cls, label):
        """Instantiate model_cls and load checkpoint['model_state_dict'] into it (eval mode)."""
        if not model_path or not os.path.exists(model_path):
            return None
        print(f"Loading {label} model from {model_path}...")
        model = model_cls().to(self.device)
        # torch.load unpickles the checkpoint: only point the *_MODEL_PATH vars at trusted files.
        checkpoint = torch.load(model_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"{label.capitalize()} model loaded successfully!")
        return model

    def _encode_text(self, text):
        """Preprocess and tokenize text; return (input_ids, attention_mask) on self.device."""
        encoding = self.tokenizer(
            preprocess_text(text),
            add_special_tokens=True,
            max_length=config.MAX_TEXT_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return encoding['input_ids'].to(self.device), encoding['attention_mask'].to(self.device)

    def _preprocess_image(self, image):
        """Run the ViT image processor and return pixel_values on self.device."""
        inputs = self.image_processor(images=image, return_tensors='pt')
        return inputs['pixel_values'].to(self.device)

    @staticmethod
    def _format_result(logits):
        """Turn (1, 2) logits into the percentage result dict (class 1 = spam)."""
        probs = torch.softmax(logits, dim=1)
        predicted = int(torch.argmax(probs, dim=1).item())
        return {
            'prediction': 'SPAM' if predicted == 1 else 'LEGITIMATE',
            'confidence': float(probs[0, predicted].item() * 100),
            'spam_probability': float(probs[0, 1].item() * 100),
            'ham_probability': float(probs[0, 0].item() * 100),
        }

    def predict_text(self, text):
        """Classify text with the text model; None if that model is not loaded."""
        if self.text_model is None:
            return None
        input_ids, attention_mask = self._encode_text(text)
        with torch.no_grad():
            return self._format_result(self.text_model(input_ids, attention_mask))

    def predict_image(self, image):
        """Classify a PIL image; None if no model/image, {'error': msg} on failure."""
        if self.image_model is None or image is None:
            return None
        try:
            pixel_values = self._preprocess_image(image)
            with torch.no_grad():
                return self._format_result(self.image_model(pixel_values))
        except Exception as e:
            return {'error': str(e)}

    def predict_fusion(self, text, image=None):
        """Classify text (+ optional image) with the fusion model; None if not loaded."""
        if self.fusion_model is None:
            return None
        input_ids, attention_mask = self._encode_text(text)
        pixel_values = None
        if image is not None:
            try:
                pixel_values = self._preprocess_image(image)
            except Exception:
                # Unusable image: fall back to the fusion model's text-only path.
                pixel_values = None
        with torch.no_grad():
            return self._format_result(self.fusion_model(input_ids, attention_mask, pixel_values))


# ================================
# FASTAPI APPLICATION
# ================================
app = FastAPI(
    title="FYP4 Spam Detection API",
    description="Email spam detection using DeBERTa and ViT models",
    version="1.0.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set by the startup hook; None until models have been loaded.
detector = None


@app.on_event("startup")
async def startup_event():
    """Load models on startup from the *_MODEL_PATH env vars (with local defaults)."""
    global detector
    text_model_path = os.getenv("TEXT_MODEL_PATH", "models/text_model.pth")
    image_model_path = os.getenv("IMAGE_MODEL_PATH", "models/image_model.pth")
    fusion_model_path = os.getenv("FUSION_MODEL_PATH", "models/fusion_model.pth")

    # Check which models exist
    text_exists = os.path.exists(text_model_path)
    image_exists = os.path.exists(image_model_path)
    fusion_exists = os.path.exists(fusion_model_path)
    print(f"Models availability: Text={text_exists}, Image={image_exists}, Fusion={fusion_exists}")

    detector = SpamDetector(
        text_model_path=text_model_path if text_exists else None,
        image_model_path=image_model_path if image_exists else None,
        fusion_model_path=fusion_model_path if fusion_exists else None,
    )
    print("API ready!")


# Pydantic models for request/response
class TextRequest(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    prediction: str
    confidence: float
    spam_probability: float
    ham_probability: float
    model_used: str


class PDFPredictionResponse(BaseModel):
    email_data: dict
    text_result: Optional[dict] = None
    image_result: Optional[dict] = None
    fusion_result: Optional[dict] = None
    final_prediction: str
    final_confidence: float


@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "FYP4 Spam Detection API",
        "version": "1.0.0",
        "endpoints": {
            "POST /predict/text": "Predict spam from text",
            "POST /predict/pdf": "Predict spam from PDF email",
            "GET /health": "Health check",
        },
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "device": str(config.DEVICE),
        "models_loaded": {
            "text": detector.text_model is not None if detector else False,
            "image": detector.image_model is not None if detector else False,
            "fusion": detector.fusion_model is not None if detector else False,
        },
    }


@app.post("/predict/text", response_model=PredictionResponse)
def predict_text(request: TextRequest):
    """Predict spam from text content.

    A plain ``def`` so FastAPI runs the blocking inference in its thread pool
    instead of on the event loop.
    """
    if detector is None or detector.text_model is None:
        raise HTTPException(status_code=503, detail="Text model not available")
    result = detector.predict_text(request.text)
    return {**result, 'model_used': 'text'}


def _analyze_pdf(spam_detector, pdf_bytes):
    """Extract text/image from a PDF, run every loaded model and pick a final verdict.

    Blocking (PDF parsing + inference); called via run_in_threadpool.

    Raises:
        HTTPException(400): the PDF text could not be extracted.
        HTTPException(500): every model that ran reported an error.
        HTTPException(503): no model produced a result.
    """
    email_data = PDFExtractor.extract_text_from_pdf(pdf_bytes)
    full_text = f"{email_data['subject']}\n\n{email_data['body']}"
    image = PDFExtractor.extract_images_from_pdf(pdf_bytes)

    results = {
        'email_data': email_data,
        'text_result': None,
        'image_result': None,
        'fusion_result': None,
    }
    if spam_detector.text_model is not None:
        results['text_result'] = spam_detector.predict_text(full_text)
    if spam_detector.image_model is not None and image is not None:
        results['image_result'] = spam_detector.predict_image(image)
    if spam_detector.fusion_model is not None:
        results['fusion_result'] = spam_detector.predict_fusion(full_text, image)

    # Priority: fusion > text > image, skipping missing results and {'error': ...} dicts.
    candidates = (results['fusion_result'], results['text_result'], results['image_result'])
    final_result = next((r for r in candidates if r and 'prediction' in r), None)
    if final_result is None:
        errors = [r['error'] for r in candidates if r and 'error' in r]
        if errors:
            raise HTTPException(status_code=500, detail=f"Prediction failed: {errors[0]}")
        raise HTTPException(status_code=503, detail="No models available for prediction")

    results['final_prediction'] = final_result['prediction']
    results['final_confidence'] = final_result['confidence']
    return results


@app.post("/predict/pdf", response_model=PDFPredictionResponse)
async def predict_pdf(file: UploadFile = File(...)):
    """Predict spam from PDF email (fusion > text > image for the final verdict)."""
    # filename may be None; accept any letter case of the extension.
    if not (file.filename or '').lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a PDF")
    if detector is None:
        raise HTTPException(status_code=503, detail="Models not loaded")

    pdf_bytes = await file.read()
    # PDF parsing and model inference are blocking; keep them off the event loop.
    return await run_in_threadpool(_analyze_pdf, detector, pdf_bytes)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)