| | """ |
| | FYP4 SPAM DETECTION API |
| | FastAPI application for email spam detection using DeBERTa and ViT models |
| | """ |
| |
|
| | import os |
| | import re |
| | import io |
| | import torch |
| | import torch.nn as nn |
| | from fastapi import FastAPI, File, UploadFile, HTTPException |
| | from fastapi.responses import JSONResponse |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel |
| | from typing import Optional |
| | import PyPDF2 |
| | import pdfplumber |
| | from PIL import Image |
| | from transformers import ( |
| | DebertaV2Model, |
| | DebertaV2Tokenizer, |
| | ViTModel, |
| | ViTImageProcessor |
| | ) |
| |
|
| | |
| | |
| | |
| |
|
class Config:
    """Central hyper-parameter and runtime configuration for the API."""
    # Runtime device: prefer GPU when available.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pretrained backbone checkpoints pulled from the Hugging Face hub.
    TEXT_MODEL = 'microsoft/deberta-v3-base'
    IMAGE_MODEL = 'google/vit-base-patch16-224-in21k'
    # Hidden widths of the two encoders (both emit 768-dim embeddings).
    TEXT_HIDDEN_DIM = 768
    IMAGE_HIDDEN_DIM = 768
    # Shared projection width used for cross-modal fusion.
    FUSION_DIM = 512
    # Binary classification: 0 = legitimate/ham, 1 = spam (see SpamDetector).
    NUM_CLASSES = 2
    DROPOUT = 0.3
    # Token budget for DeBERTa input.
    MAX_TEXT_LENGTH = 256
    # ViT patch-16 input resolution.
    IMG_SIZE = 224


# Module-level singleton used throughout the file.
config = Config()
| |
|
| | |
| | |
| | |
| |
|
class DeBERTaTextEncoder(nn.Module):
    """Encode tokenized text with DeBERTa-v3 and project into the fusion space.

    The first token's hidden state is projected from ``config.TEXT_HIDDEN_DIM``
    down to ``config.FUSION_DIM`` through dropout, a linear map, LayerNorm
    and GELU.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL)
        self.projection = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(config.TEXT_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        )

    def forward(self, input_ids, attention_mask):
        """Return a (batch, FUSION_DIM) embedding for the token batch."""
        hidden = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        first_token = hidden.last_hidden_state[:, 0, :]
        return self.projection(first_token)
| |
|
| |
|
class ViTImageEncoder(nn.Module):
    """Encode an image with ViT and project into the fusion space.

    Uses the first ([CLS]-position) token of the ViT output and projects it
    from ``config.IMAGE_HIDDEN_DIM`` to ``config.FUSION_DIM``.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.vit = ViTModel.from_pretrained(config.IMAGE_MODEL)
        self.projection = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(config.IMAGE_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        )

    def forward(self, pixel_values):
        """Return a (batch, FUSION_DIM) embedding for the pixel batch."""
        hidden = self.vit(pixel_values=pixel_values, return_dict=True)
        first_token = hidden.last_hidden_state[:, 0, :]
        return self.projection(first_token)
| |
|
| |
|
| | class CrossModalAttention(nn.Module): |
| | def __init__(self, dim=512, num_heads=8, dropout=0.1): |
| | super(CrossModalAttention, self).__init__() |
| | self.cross_attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True) |
| | self.norm1 = nn.LayerNorm(dim) |
| | self.norm2 = nn.LayerNorm(dim) |
| | self.ffn = nn.Sequential( |
| | nn.Linear(dim, dim * 4), |
| | nn.GELU(), |
| | nn.Dropout(dropout), |
| | nn.Linear(dim * 4, dim), |
| | nn.Dropout(dropout) |
| | ) |
| |
|
| | def forward(self, text_features, image_features): |
| | text_features = text_features.unsqueeze(1) |
| | image_features = image_features.unsqueeze(1) |
| | attn_output, _ = self.cross_attn(text_features, image_features, image_features) |
| | fused = self.norm1(text_features + attn_output) |
| | ffn_output = self.ffn(fused) |
| | output = self.norm2(fused + ffn_output) |
| | return output.squeeze(1) |
| |
|
| |
|
class TextSpamClassifier(nn.Module):
    """Text-only spam classifier: DeBERTa encoder followed by an MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        # FUSION_DIM -> 256 -> 128 -> NUM_CLASSES, with LayerNorm/GELU/dropout
        # between the linear stages.
        self.classifier = nn.Sequential(
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        )

    def forward(self, input_ids, attention_mask):
        """Return raw (batch, NUM_CLASSES) logits for the token batch."""
        encoded = self.text_encoder(input_ids, attention_mask)
        return self.classifier(encoded)
| |
|
| |
|
class ImageSpamClassifier(nn.Module):
    """Image-only spam classifier: ViT encoder followed by an MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.image_encoder = ViTImageEncoder(dropout)
        # Same head shape as the text classifier: FUSION_DIM -> 256 -> 128 -> 2.
        self.classifier = nn.Sequential(
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        )

    def forward(self, pixel_values):
        """Return raw (batch, NUM_CLASSES) logits for the pixel batch."""
        encoded = self.image_encoder(pixel_values)
        return self.classifier(encoded)
| |
|
| |
|
class FusionSpamClassifier(nn.Module):
    """Multimodal spam classifier with graceful single-modality fallback.

    When both modalities are supplied the text and image embeddings are
    combined by cross-modal attention; with only one modality present that
    encoder's output feeds the head directly.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        self.image_encoder = ViTImageEncoder(dropout)
        self.cross_modal_fusion = CrossModalAttention(config.FUSION_DIM, num_heads=8, dropout=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        )

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None):
        """Return (batch, NUM_CLASSES) logits from text, image, or both.

        Raises:
            ValueError: when neither input_ids nor pixel_values is given.
        """
        have_text = input_ids is not None
        have_image = pixel_values is not None
        if have_text and have_image:
            text_features = self.text_encoder(input_ids, attention_mask)
            image_features = self.image_encoder(pixel_values)
            features = self.cross_modal_fusion(text_features, image_features)
        elif have_text:
            features = self.text_encoder(input_ids, attention_mask)
        elif have_image:
            features = self.image_encoder(pixel_values)
        else:
            raise ValueError("At least one modality required")
        return self.classifier(features)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class PDFExtractor:
    """Static helpers for pulling email text and images out of PDF bytes."""

    @staticmethod
    def extract_text_from_pdf(pdf_bytes):
        """Extract subject, sender and body text from a PDF email.

        Tries pdfplumber first, then falls back to PyPDF2 when pdfplumber
        cannot parse the document.

        Args:
            pdf_bytes: Raw PDF file contents.

        Returns:
            Dict with 'subject', 'sender', 'body' and 'full_text' keys;
            header fields stay empty strings when they cannot be located.

        Raises:
            HTTPException: 400 when neither backend can read the PDF.
        """
        email_data = {
            'subject': '',
            'sender': '',
            'body': '',
            'full_text': ''
        }

        try:
            with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
                full_text = ""
                for page in pdf.pages:
                    text = page.extract_text()
                    if text:
                        full_text += text + "\n"

            email_data['full_text'] = full_text

            # Pull common email headers out of the rendered text.
            patterns = {
                'subject': [r'Subject:\s*(.+)', r'SUBJECT:\s*(.+)'],
                'sender': [r'From:\s*(.+)', r'FROM:\s*(.+)']
            }
            for field, pattern_list in patterns.items():
                for pattern in pattern_list:
                    match = re.search(pattern, full_text, re.IGNORECASE)
                    if match:
                        email_data[field] = match.group(1).strip()[:100]
                        break

            # Body = everything after the first blank line following a header
            # line; fall back to the whole document when no header is found.
            body_match = re.search(r'(?:Subject|Date|From|To):.+?\n\n(.+)', full_text, re.DOTALL | re.IGNORECASE)
            email_data['body'] = body_match.group(1).strip() if body_match else full_text

            return email_data
        except Exception:
            # pdfplumber failed; retry with PyPDF2 before giving up.
            # (Narrowed from a bare `except:` so Ctrl-C/SystemExit propagate.)
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
                full_text = ""
                for page in pdf_reader.pages:
                    text = page.extract_text()
                    if text:
                        full_text += text + "\n"
                email_data['full_text'] = full_text
                email_data['body'] = full_text
                return email_data
            except Exception as e:
                raise HTTPException(status_code=400, detail=f"Error extracting text from PDF: {str(e)}")

    @staticmethod
    def extract_images_from_pdf(pdf_bytes):
        """Best-effort extraction of the first embedded image in the PDF.

        Args:
            pdf_bytes: Raw PDF file contents.

        Returns:
            A PIL Image, or None when no decodable image is found or the PDF
            cannot be parsed at all.
        """
        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
            for page in pdf_reader.pages:
                # Guard the lookup: pages without /Resources used to raise
                # KeyError and were silently eaten by a bare `except:`.
                resources = page.get('/Resources')
                if not resources or '/XObject' not in resources:
                    continue
                xobjects = resources['/XObject'].get_object()
                for name in xobjects:
                    candidate = xobjects[name]
                    if candidate.get('/Subtype') != '/Image':
                        continue
                    try:
                        size = (candidate['/Width'], candidate['/Height'])
                        data = candidate.get_data()
                        # NOTE(review): only DeviceRGB is mapped explicitly;
                        # everything else falls back to palette mode, which may
                        # mis-render CMYK/grayscale images — confirm acceptable.
                        mode = "RGB" if candidate['/ColorSpace'] == '/DeviceRGB' else "P"
                        return Image.frombytes(mode, size, data)
                    except Exception:
                        # Skip streams PIL cannot decode (e.g. exotic filters).
                        continue
        except Exception:
            # Image extraction is best-effort: any parse failure means "no image".
            pass
        return None
| |
|
| |
|
| | |
| | |
| | |
| |
|
def preprocess_text(text):
    """Normalize raw email text for the tokenizer.

    Lowercases the input and replaces URLs, email addresses and digit runs
    with the placeholder tokens [URL], [EMAIL] and [NUM], then collapses all
    whitespace to single spaces.
    """
    normalized = str(text).lower()
    # Order matters: URLs first so their digits never become [NUM] fragments
    # of a half-replaced address.
    replacements = (
        (r'http\S+|www\.\S+', '[URL]'),
        (r'\S+@\S+', '[EMAIL]'),
        (r'\d+', '[NUM]'),
    )
    for pattern, token in replacements:
        normalized = re.sub(pattern, token, normalized)
    return re.sub(r'\s+', ' ', normalized).strip()
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SpamDetector:
    """Inference wrapper holding the tokenizer, image processor and models.

    Each model is loaded only when its checkpoint path is provided and the
    file exists; the corresponding ``predict_*`` method returns None when
    its model is absent. Refactored to share one checkpoint loader, one
    text-encoding helper and one result formatter instead of three copies.
    """

    def __init__(self, text_model_path=None, image_model_path=None, fusion_model_path=None):
        self.device = config.DEVICE
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL)
        self.image_processor = ViTImageProcessor.from_pretrained(config.IMAGE_MODEL)

        self.text_model = self._load_checkpoint(TextSpamClassifier, text_model_path, "text")
        self.image_model = self._load_checkpoint(ImageSpamClassifier, image_model_path, "image")
        self.fusion_model = self._load_checkpoint(FusionSpamClassifier, fusion_model_path, "fusion")

    def _load_checkpoint(self, model_cls, path, label):
        """Instantiate model_cls and load weights from a checkpoint file.

        Returns the model in eval mode on self.device, or None when path is
        falsy or does not exist. Checkpoints are expected to hold their
        weights under the 'model_state_dict' key.
        """
        if not (path and os.path.exists(path)):
            return None
        print(f"Loading {label} model from {path}...")
        model = model_cls().to(self.device)
        checkpoint = torch.load(path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"{label.capitalize()} model loaded successfully!")
        return model

    def _encode_text(self, text):
        """Tokenize preprocessed text; returns (input_ids, attention_mask) on device."""
        encoding = self.tokenizer(
            preprocess_text(text),
            add_special_tokens=True,
            max_length=config.MAX_TEXT_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return encoding['input_ids'].to(self.device), encoding['attention_mask'].to(self.device)

    @staticmethod
    def _format_result(probs, predicted):
        """Convert (1, 2) class probabilities into the API's response dict."""
        return {
            'prediction': 'SPAM' if predicted.item() == 1 else 'LEGITIMATE',
            'confidence': float(probs[0, predicted.item()].item() * 100),
            'spam_probability': float(probs[0, 1].item() * 100),
            'ham_probability': float(probs[0, 0].item() * 100)
        }

    def predict_text(self, text):
        """Classify raw text; returns a result dict or None without a text model."""
        if not self.text_model:
            return None
        input_ids, attention_mask = self._encode_text(text)
        with torch.no_grad():
            outputs = self.text_model(input_ids, attention_mask)
            probs = torch.softmax(outputs, dim=1)
            predicted = torch.argmax(probs, dim=1)
        return self._format_result(probs, predicted)

    def predict_image(self, image):
        """Classify a PIL image; returns a result dict, {'error': ...}, or None."""
        if not self.image_model or image is None:
            return None
        try:
            inputs = self.image_processor(images=image, return_tensors='pt')
            pixel_values = inputs['pixel_values'].to(self.device)
            with torch.no_grad():
                outputs = self.image_model(pixel_values)
                probs = torch.softmax(outputs, dim=1)
                predicted = torch.argmax(probs, dim=1)
            return self._format_result(probs, predicted)
        except Exception as e:
            return {'error': str(e)}

    def predict_fusion(self, text, image=None):
        """Classify text plus optional image with the fusion model.

        A failed image preprocessing step silently degrades to text-only
        fusion (original best-effort behavior, with the bare `except:`
        narrowed to Exception).
        """
        if not self.fusion_model:
            return None
        input_ids, attention_mask = self._encode_text(text)
        pixel_values = None
        if image is not None:
            try:
                image_inputs = self.image_processor(images=image, return_tensors='pt')
                pixel_values = image_inputs['pixel_values'].to(self.device)
            except Exception:
                pixel_values = None
        with torch.no_grad():
            outputs = self.fusion_model(input_ids, attention_mask, pixel_values)
            probs = torch.softmax(outputs, dim=1)
            predicted = torch.argmax(probs, dim=1)
        return self._format_result(probs, predicted)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# FastAPI application instance; metadata shows up in the generated docs.
app = FastAPI(
    title="FYP4 Spam Detection API",
    description="Email spam detection using DeBERTa and ViT models",
    version="1.0.0"
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec — tighten the origin list if
# credentialed requests are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global detector instance, populated by the startup event handler below.
detector = None
| |
|
| |
|
@app.on_event("startup")
async def startup_event():
    """Load models on startup.

    Only the text model is shipped with this deployment; image and fusion
    checkpoints are never loaded here.
    """
    global detector

    # Debug output to diagnose missing-checkpoint deployments.
    print(f"Current directory: {os.getcwd()}")
    print(f"Files in current directory: {os.listdir('.')}")

    # BUGFIX: the path was the garbled "modelstext_model..pth" (missing
    # directory separator, doubled dot), which could never exist on disk.
    text_model_path = os.path.join("models", "text_model.pth")

    print(f"Looking for model at: {text_model_path}")
    print(f"File exists: {os.path.exists(text_model_path)}")

    text_exists = os.path.exists(text_model_path)

    print(f"Models availability: Text={text_exists}, Image=False, Fusion=False")

    if text_exists:
        print("Loading text model...")
    else:
        print("ERROR: Text model file not found!")

    # SpamDetector tolerates None paths and simply skips those models.
    detector = SpamDetector(
        text_model_path=text_model_path if text_exists else None,
        image_model_path=None,
        fusion_model_path=None
    )

    print("API ready!")
| |
|
| |
|
| | |
class TextRequest(BaseModel):
    """Request body for POST /predict/text."""
    # Raw email/text content to classify.
    text: str
| |
|
| |
|
class PredictionResponse(BaseModel):
    """Response body for POST /predict/text."""
    # 'SPAM' or 'LEGITIMATE'.
    prediction: str
    # Confidence (%) of the predicted class.
    confidence: float
    # Per-class probabilities as percentages.
    spam_probability: float
    ham_probability: float
    # Which classifier produced the result (e.g. 'text').
    classifier_type: str
| |
|
| |
|
class PDFPredictionResponse(BaseModel):
    """Response body for POST /predict/pdf."""
    # Extracted subject/sender/body/full_text from the PDF.
    email_data: dict
    # Per-model results; None when that model is not loaded.
    text_result: Optional[dict]
    image_result: Optional[dict]
    fusion_result: Optional[dict]
    # Verdict taken from the highest-priority available model.
    final_prediction: str
    final_confidence: float
| |
|
| |
|
@app.get("/")
async def root():
    """Root endpoint with API information"""
    endpoints = {
        "POST /predict/text": "Predict spam from text",
        "POST /predict/pdf": "Predict spam from PDF email",
        "GET /health": "Health check"
    }
    return {
        "message": "FYP4 Spam Detection API",
        "version": "1.0.0",
        "endpoints": endpoints
    }
| |
|
| |
|
@app.get("/health")
async def health_check():
    """Health check endpoint reporting device and loaded-model status."""
    models_loaded = {
        name: (getattr(detector, attr) is not None) if detector else False
        for name, attr in (
            ("text", "text_model"),
            ("image", "image_model"),
            ("fusion", "fusion_model"),
        )
    }
    return {
        "status": "healthy",
        "device": str(config.DEVICE),
        "models_loaded": models_loaded
    }
| |
|
| |
|
@app.post("/predict/text", response_model=PredictionResponse)
async def predict_text(request: TextRequest):
    """Classify raw text with the text-only model; 503 when it is unavailable."""
    if detector is None or detector.text_model is None:
        raise HTTPException(status_code=503, detail="Text model not available")

    result = detector.predict_text(request.text)
    result['classifier_type'] = 'text'
    return result
| |
|
| |
|
@app.post("/predict/pdf", response_model=PDFPredictionResponse)
async def predict_pdf(file: UploadFile = File(...)):
    """Predict spam from PDF email.

    Extracts text (and the first embedded image, if any) from the uploaded
    PDF, runs every loaded model, and picks the final verdict with priority
    fusion > text > image.
    """
    # BUGFIX: UploadFile.filename can be None (AttributeError before), and
    # the old check was case-sensitive, rejecting e.g. "mail.PDF".
    if not file.filename or not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a PDF")

    if not detector:
        raise HTTPException(status_code=503, detail="Models not loaded")

    pdf_bytes = await file.read()

    email_data = PDFExtractor.extract_text_from_pdf(pdf_bytes)
    full_text = f"{email_data['subject']}\n\n{email_data['body']}"
    image = PDFExtractor.extract_images_from_pdf(pdf_bytes)

    results = {
        'email_data': email_data,
        'text_result': None,
        'image_result': None,
        'fusion_result': None
    }

    if detector.text_model:
        results['text_result'] = detector.predict_text(full_text)

    # Explicit None check: PIL images should not be truth-tested.
    if detector.image_model and image is not None:
        results['image_result'] = detector.predict_image(image)

    if detector.fusion_model:
        results['fusion_result'] = detector.predict_fusion(full_text, image)

    # Priority: multimodal fusion, then text-only, then image-only.
    final_result = results['fusion_result'] or results['text_result'] or results['image_result']

    if not final_result:
        raise HTTPException(status_code=503, detail="No models available for prediction")

    results['final_prediction'] = final_result['prediction']
    results['final_confidence'] = final_result['confidence']

    return results
| |
|
| |
|
| | if __name__ == "__main__": |
| | import uvicorn |
| | uvicorn.run(app, host="0.0.0.0", port=7860) |
| |
|