Spaces:
Sleeping
Sleeping
| """ | |
| FYP4 SPAM DETECTION API | |
| FastAPI application for email spam detection using DeBERTa and ViT models | |
| """ | |
| import os | |
| import re | |
| import io | |
| import torch | |
| import torch.nn as nn | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import PyPDF2 | |
| import pdfplumber | |
| from PIL import Image | |
| from transformers import ( | |
| DebertaV2Model, | |
| DebertaV2Tokenizer, | |
| ViTModel, | |
| ViTImageProcessor | |
| ) | |
| # ================================ | |
| # CONFIGURATION | |
| # ================================ | |
class Config:
    """Static settings shared by the model definitions and the inference service."""

    # First CUDA device when one is visible, otherwise the CPU.
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Hugging Face Hub identifiers of the pretrained backbones.
    TEXT_MODEL = 'microsoft/deberta-v3-base'
    IMAGE_MODEL = 'google/vit-base-patch16-224-in21k'

    # Hidden width of each backbone and width of the shared projection space.
    TEXT_HIDDEN_DIM = 768
    IMAGE_HIDDEN_DIM = 768
    FUSION_DIM = 512

    # Two output classes (ham / spam) and the head dropout rate.
    NUM_CLASSES = 2
    DROPOUT = 0.3

    # Tokenizer truncation/padding length; image side length matching the
    # "patch16-224" ViT checkpoint.
    MAX_TEXT_LENGTH = 256
    IMG_SIZE = 224


# Module-level settings object used throughout the file.
config = Config()
| # ================================ | |
| # MODEL ARCHITECTURES | |
| # ================================ | |
class DeBERTaTextEncoder(nn.Module):
    """Encode token ids with DeBERTa and project the first-token state to FUSION_DIM."""

    def __init__(self, dropout=0.3):
        super().__init__()
        # Attribute names and layer order define the state-dict keys that saved
        # checkpoints are loaded against; keep them fixed.
        self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL)
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(config.TEXT_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return one FUSION_DIM vector per sequence, taken from position 0."""
        hidden = self.deberta(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        first_token = hidden[:, 0, :]
        return self.projection(first_token)
class ViTImageEncoder(nn.Module):
    """Encode pixel values with ViT and project the first-token state to FUSION_DIM."""

    def __init__(self, dropout=0.3):
        super().__init__()
        # Attribute names and layer order define the checkpoint state-dict keys.
        self.vit = ViTModel.from_pretrained(config.IMAGE_MODEL)
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(config.IMAGE_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return one FUSION_DIM vector per image, taken from position 0."""
        hidden = self.vit(pixel_values=pixel_values, return_dict=True).last_hidden_state
        first_token = hidden[:, 0, :]
        return self.projection(first_token)
class CrossModalAttention(nn.Module):
    """Fuse a text vector with an image vector through one cross-attention block.

    The text vector is the query and the image vector provides key and value.
    The attention output goes through residual + LayerNorm, then a 4x-wide
    feed-forward layer with a second residual + LayerNorm.
    """

    def __init__(self, dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        expanded = dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, dim),
            nn.Dropout(dropout),
        )

    def forward(self, text_features, image_features):
        # Each input vector becomes a length-1 sequence for the batch_first
        # attention; assumes (batch, dim) inputs.
        query = text_features.unsqueeze(1)
        context = image_features.unsqueeze(1)
        attended = self.cross_attn(query, context, context)[0]
        residual = self.norm1(query + attended)
        refined = self.norm2(residual + self.ffn(residual))
        return refined.squeeze(1)
class TextSpamClassifier(nn.Module):
    """Text-only spam classifier: DeBERTa encoder followed by an MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        # FUSION_DIM -> 256 -> 128 -> NUM_CLASSES. The layer order fixes the
        # "classifier.<index>" checkpoint keys.
        self.classifier = nn.Sequential(
            nn.Linear(config.FUSION_DIM, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        )

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits for a batch of token sequences."""
        return self.classifier(self.text_encoder(input_ids, attention_mask))
class ImageSpamClassifier(nn.Module):
    """Image-only spam classifier: ViT encoder followed by an MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.image_encoder = ViTImageEncoder(dropout)
        # FUSION_DIM -> 256 -> 128 -> NUM_CLASSES. The layer order fixes the
        # "classifier.<index>" checkpoint keys.
        self.classifier = nn.Sequential(
            nn.Linear(config.FUSION_DIM, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        )

    def forward(self, pixel_values):
        """Return unnormalised class logits for a batch of images."""
        return self.classifier(self.image_encoder(pixel_values))
class FusionSpamClassifier(nn.Module):
    """Spam classifier that accepts text, an image, or both.

    With both inputs, the text vector attends to the image vector through
    ``CrossModalAttention``. With a single input, that encoder's projection is
    classified directly.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        self.image_encoder = ViTImageEncoder(dropout)
        self.cross_modal_fusion = CrossModalAttention(config.FUSION_DIM, num_heads=8, dropout=dropout)
        # FUSION_DIM -> 256 -> 128 -> NUM_CLASSES, same head shape as the single-modality models.
        self.classifier = nn.Sequential(
            nn.Linear(config.FUSION_DIM, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        )

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None):
        """Return class logits for whichever modalities were supplied.

        Raises:
            ValueError: if neither ``input_ids`` nor ``pixel_values`` is given.
        """
        has_text = input_ids is not None
        has_image = pixel_values is not None
        if not (has_text or has_image):
            raise ValueError("At least one modality required")
        if has_text and has_image:
            # Arguments are evaluated left to right: text encoder first, then image.
            features = self.cross_modal_fusion(
                self.text_encoder(input_ids, attention_mask),
                self.image_encoder(pixel_values),
            )
        elif has_text:
            features = self.text_encoder(input_ids, attention_mask)
        else:
            features = self.image_encoder(pixel_values)
        return self.classifier(features)
| # ================================ | |
| # PDF EXTRACTION | |
| # ================================ | |
class PDFExtractor:
    """Static helpers that pull email text and an embedded image out of PDF bytes."""

    # Header fields looked up case-insensitively in the extracted text.
    # ``[ \t]*`` (not ``\s*``) stops an empty header from capturing the
    # following line as its value.
    _HEADER_PATTERNS = {
        'subject': r'Subject:[ \t]*(.+)',
        'sender': r'From:[ \t]*(.+)',
    }
    # Body = everything after the first blank line that follows a header.
    _BODY_PATTERN = r'(?:Subject|Date|From|To):.+?\n\n(.+)'
    # Stream filters whose decoded data is still a complete compressed image file.
    _ENCODED_IMAGE_FILTERS = ('/DCTDecode', '/JPXDecode')
    # PDF device colour spaces mapped to Pillow raw modes.
    _RAW_MODES = {'/DeviceRGB': 'RGB', '/DeviceGray': 'L', '/DeviceCMYK': 'CMYK'}

    @staticmethod
    def extract_text_from_pdf(pdf_bytes):
        """Extract text from PDF bytes and split out common email header fields.

        pdfplumber is tried first; if it raises, PyPDF2 is used instead. Header
        parsing runs on whichever extractor produced the text.

        Args:
            pdf_bytes: Raw bytes of the uploaded PDF.

        Returns:
            dict with keys ``subject`` and ``sender`` (each at most 100 chars,
            '' when absent), ``body`` and ``full_text``.

        Raises:
            HTTPException: 400 when neither library can read the PDF.
        """
        try:
            full_text = PDFExtractor._text_with_pdfplumber(pdf_bytes)
        except Exception:
            # pdfplumber could not parse the file; retry with PyPDF2's parser.
            try:
                full_text = PDFExtractor._text_with_pypdf2(pdf_bytes)
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Error extracting text from PDF: {str(e)}",
                ) from e
        return PDFExtractor._parse_email_fields(full_text)

    @staticmethod
    def extract_images_from_pdf(pdf_bytes):
        """Return the first decodable embedded image as an RGB PIL image, or None.

        Best effort by design: pages or images that cannot be read are skipped,
        and an unreadable PDF yields ``None`` rather than an error.
        """
        try:
            reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
            for page in reader.pages:
                image = PDFExtractor._first_image_on_page(page)
                if image is not None:
                    return image
        except Exception:
            # Corrupt or unsupported PDF: report "no image" so text-only
            # prediction can still proceed.
            return None
        return None

    @staticmethod
    def _text_with_pdfplumber(pdf_bytes):
        """Return the concatenated page text extracted by pdfplumber."""
        with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
            return PDFExtractor._join_page_texts(pdf.pages)

    @staticmethod
    def _text_with_pypdf2(pdf_bytes):
        """Return the concatenated page text extracted by PyPDF2."""
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        return PDFExtractor._join_page_texts(reader.pages)

    @staticmethod
    def _join_page_texts(pages):
        """Join ``page.extract_text()`` results, each newline-terminated, skipping empty pages."""
        chunks = []
        for page in pages:
            text = page.extract_text()
            if text:
                chunks.append(text + "\n")
        return "".join(chunks)

    @staticmethod
    def _parse_email_fields(full_text):
        """Split raw PDF text into subject / sender / body / full_text."""
        email_data = {'subject': '', 'sender': '', 'body': '', 'full_text': full_text}
        for field, pattern in PDFExtractor._HEADER_PATTERNS.items():
            match = re.search(pattern, full_text, re.IGNORECASE)
            if match:
                email_data[field] = match.group(1).strip()[:100]
        body_match = re.search(PDFExtractor._BODY_PATTERN, full_text, re.DOTALL | re.IGNORECASE)
        # Without a header/blank-line structure the whole text is the body.
        email_data['body'] = body_match.group(1).strip() if body_match else full_text
        return email_data

    @staticmethod
    def _first_image_on_page(page):
        """Return the first image XObject on ``page`` that decodes, else None."""
        try:
            resources = page['/Resources']
            if '/XObject' not in resources:
                return None
            xobjects = resources['/XObject'].get_object()
        except Exception:
            # Page without usable /Resources: move on to the next page.
            return None
        for name in xobjects:
            try:
                image = PDFExtractor._decode_image_xobject(xobjects[name])
            except Exception:
                # Unsupported filter/colour space or truncated data: try the next object.
                continue
            if image is not None:
                return image
        return None

    @staticmethod
    def _decode_image_xobject(xobj):
        """Decode one XObject into an RGB PIL image; None if it is not an image."""
        if xobj['/Subtype'] != '/Image':
            return None
        data = xobj.get_data()
        filters = xobj['/Filter'] if '/Filter' in xobj else []
        if not isinstance(filters, list):
            filters = [filters]
        if any(f in PDFExtractor._ENCODED_IMAGE_FILTERS for f in filters):
            # JPEG / JPEG 2000 streams are complete image files, not raw pixels.
            image = Image.open(io.BytesIO(data))
        else:
            color_space = xobj['/ColorSpace'] if '/ColorSpace' in xobj else None
            # Array colour spaces (e.g. ICCBased) are unhashable; they fall back to 'P'.
            if isinstance(color_space, str):
                mode = PDFExtractor._RAW_MODES.get(color_space, 'P')
            else:
                mode = 'P'
            size = (int(xobj['/Width']), int(xobj['/Height']))
            image = Image.frombytes(mode, size, data)
        # Downstream ViT preprocessing is given three-channel images.
        return image.convert('RGB')
| # ================================ | |
| # TEXT PREPROCESSING | |
| # ================================ | |
def preprocess_text(text):
    """Normalise raw text before tokenisation.

    Lower-cases the input (after ``str()`` conversion) and masks URLs, e-mail
    addresses and digit runs with the tokens ``[URL]``, ``[EMAIL]`` and
    ``[NUM]``. Whitespace runs are collapsed to single spaces and the ends
    are trimmed.
    """
    # Order matters: URLs are masked before e-mail addresses, and both before
    # digits, so digits inside them never turn into separate [NUM] tokens.
    substitutions = (
        (r'http\S+|www\.\S+', '[URL]'),
        (r'\S+@\S+', '[EMAIL]'),
        (r'\d+', '[NUM]'),
        (r'\s+', ' '),
    )
    cleaned = str(text).lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
| # ================================ | |
| # SPAM DETECTOR | |
| # ================================ | |
class SpamDetector:
    """Load the trained classifiers and turn raw text/images into spam predictions.

    Any of the three models may be absent; the matching ``predict_*`` method
    then returns ``None``. Each prediction is a dict with ``prediction``
    ('SPAM' / 'LEGITIMATE'), ``confidence``, ``spam_probability`` and
    ``ham_probability`` (all percentages).
    """

    def __init__(self, text_model_path=None, image_model_path=None, fusion_model_path=None):
        self.device = config.DEVICE
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL)
        self.image_processor = ViTImageProcessor.from_pretrained(config.IMAGE_MODEL)
        # Each model is loaded only when its checkpoint path is given and exists.
        self.text_model = self._load_model(TextSpamClassifier, text_model_path, 'text')
        self.image_model = self._load_model(ImageSpamClassifier, image_model_path, 'image')
        self.fusion_model = self._load_model(FusionSpamClassifier, fusion_model_path, 'fusion')

    def _load_model(self, model_cls, path, label):
        """Instantiate ``model_cls`` on ``self.device`` and load weights from ``path``.

        Returns None when ``path`` is empty or does not exist. The checkpoint
        must be a dict holding the weights under ``'model_state_dict'``.
        """
        if not path or not os.path.exists(path):
            return None
        print(f"Loading {label} model from {path}...")
        model = model_cls().to(self.device)
        # NOTE(review): torch.load deserialises with pickle; only point these
        # paths at trusted checkpoint files.
        checkpoint = torch.load(path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"{label.capitalize()} model loaded successfully!")
        return model

    def _encode_text(self, text):
        """Preprocess and tokenize ``text``; return (input_ids, attention_mask) on self.device."""
        encoding = self.tokenizer(
            preprocess_text(text),
            add_special_tokens=True,
            max_length=config.MAX_TEXT_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return encoding['input_ids'].to(self.device), encoding['attention_mask'].to(self.device)

    def _encode_image(self, image):
        """Return processor pixel values for ``image`` on self.device."""
        # Palette / greyscale / CMYK / RGBA images are converted so the
        # processor receives three channels; objects without ``mode`` pass through.
        mode = getattr(image, 'mode', None)
        if mode is not None and mode != 'RGB':
            image = image.convert('RGB')
        inputs = self.image_processor(images=image, return_tensors='pt')
        return inputs['pixel_values'].to(self.device)

    @staticmethod
    def _format_result(logits):
        """Turn batch-of-one class logits into the API prediction dict (percentages)."""
        probs = torch.softmax(logits, dim=1)
        label = int(torch.argmax(probs, dim=1).item())
        return {
            'prediction': 'SPAM' if label == 1 else 'LEGITIMATE',
            'confidence': float(probs[0, label].item() * 100),
            'spam_probability': float(probs[0, 1].item() * 100),
            'ham_probability': float(probs[0, 0].item() * 100),
        }

    def predict_text(self, text):
        """Classify ``text`` with the text-only model; None if that model is not loaded."""
        if self.text_model is None:
            return None
        input_ids, attention_mask = self._encode_text(text)
        with torch.no_grad():
            logits = self.text_model(input_ids, attention_mask)
        return self._format_result(logits)

    def predict_image(self, image):
        """Classify an image with the image-only model.

        Returns None when the model is not loaded or ``image`` is None, and
        ``{'error': message}`` when preprocessing or inference fails.
        """
        if self.image_model is None or image is None:
            return None
        try:
            pixel_values = self._encode_image(image)
            with torch.no_grad():
                logits = self.image_model(pixel_values)
            return self._format_result(logits)
        except Exception as e:
            return {'error': str(e)}

    def predict_fusion(self, text, image=None):
        """Classify text plus an optional image with the fusion model.

        If the image cannot be preprocessed, the fusion model runs on the
        text alone. Returns None when the fusion model is not loaded.
        """
        if self.fusion_model is None:
            return None
        input_ids, attention_mask = self._encode_text(text)
        pixel_values = None
        if image is not None:
            try:
                pixel_values = self._encode_image(image)
            except Exception as e:
                # Degrade to text-only instead of failing the whole request.
                print(f"Image preprocessing failed; using text only for fusion: {e}")
        with torch.no_grad():
            logits = self.fusion_model(input_ids, attention_mask, pixel_values)
        return self._format_result(logits)
| # ================================ | |
| # FASTAPI APPLICATION | |
| # ================================ | |
app = FastAPI(
    title="FYP4 Spam Detection API",
    description="Email spam detection using DeBERTa and ViT models",
    version="1.0.0",
)

# Cross-origin access is fully open: any origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive combination; confirm it is intended or list explicit origins.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Shared SpamDetector instance; stays None until the startup hook has run.
detector = None


@app.on_event("startup")
async def startup_event():
    """Build the global SpamDetector from whichever checkpoints exist on disk.

    Checkpoint locations come from TEXT_MODEL_PATH / IMAGE_MODEL_PATH /
    FUSION_MODEL_PATH, defaulting to files under ``models/``. A missing file
    leaves that model unloaded rather than failing startup.
    """
    global detector
    text_model_path = os.getenv("TEXT_MODEL_PATH", "models/text_model.pth")
    image_model_path = os.getenv("IMAGE_MODEL_PATH", "models/image_model.pth")
    fusion_model_path = os.getenv("FUSION_MODEL_PATH", "models/fusion_model.pth")

    text_exists = os.path.exists(text_model_path)
    image_exists = os.path.exists(image_model_path)
    fusion_exists = os.path.exists(fusion_model_path)
    print(f"Models availability: Text={text_exists}, Image={image_exists}, Fusion={fusion_exists}")

    detector = SpamDetector(
        text_model_path=text_model_path if text_exists else None,
        image_model_path=image_model_path if image_exists else None,
        fusion_model_path=fusion_model_path if fusion_exists else None,
    )
    print("API ready!")
| # Pydantic models for request/response | |
class TextRequest(BaseModel):
    """JSON body accepted by the text prediction endpoint."""

    # Raw text to classify; normalisation happens server-side.
    text: str
class PredictionResponse(BaseModel):
    """Response schema for a single-model prediction."""

    # 'SPAM' or 'LEGITIMATE'.
    prediction: str
    # Probability of the predicted class, as a percentage.
    confidence: float
    # Per-class probabilities as percentages.
    spam_probability: float
    ham_probability: float
    # Which model produced the result, e.g. 'text'.
    model_used: str
class PDFPredictionResponse(BaseModel):
    """Response schema for a PDF email prediction."""

    # Extracted 'subject', 'sender', 'body' and 'full_text'.
    email_data: dict
    # Per-model prediction dicts; None when that model did not run.
    text_result: Optional[dict]
    image_result: Optional[dict]
    fusion_result: Optional[dict]
    # Verdict and confidence (percentage) of the preferred available model.
    final_prediction: str
    final_confidence: float
@app.get("/")
async def root():
    """Root endpoint with API information and the list of available routes."""
    return {
        "message": "FYP4 Spam Detection API",
        "version": "1.0.0",
        "endpoints": {
            "POST /predict/text": "Predict spam from text",
            "POST /predict/pdf": "Predict spam from PDF email",
            "GET /health": "Health check",
        },
    }
@app.get("/health")
async def health_check():
    """Report the compute device and which models are loaded.

    Always answers "healthy". Before the startup hook has created
    ``detector``, every model is reported as not loaded.
    """
    ready = detector is not None
    return {
        "status": "healthy",
        "device": str(config.DEVICE),
        "models_loaded": {
            "text": ready and detector.text_model is not None,
            "image": ready and detector.image_model is not None,
            "fusion": ready and detector.fusion_model is not None,
        },
    }
@app.post("/predict/text", response_model=PredictionResponse)
async def predict_text(request: TextRequest):
    """Predict spam from text content using the text-only model.

    Raises:
        HTTPException: 503 if the text model is unavailable, 400 if the text is blank.
    """
    if detector is None or detector.text_model is None:
        raise HTTPException(status_code=503, detail="Text model not available")
    if not request.text.strip():
        # A blank string carries no signal; reject it instead of returning a guess.
        raise HTTPException(status_code=400, detail="Text must not be empty")
    result = detector.predict_text(request.text)
    return {**result, 'model_used': 'text'}
@app.post("/predict/pdf", response_model=PDFPredictionResponse)
async def predict_pdf(file: UploadFile = File(...)):
    """Predict spam from a PDF email using every loaded model.

    The subject and body go to the text model, and the first embedded image
    (if any) to the image model. Both go to the fusion model. The final
    verdict prefers fusion, then text, then image, skipping results that
    carry an error instead of a prediction.

    Raises:
        HTTPException: 400 for a non-PDF upload or unreadable PDF text;
            503 when models are not loaded or none produced a prediction.
    """
    # filename can be None; compare the extension case-insensitively.
    if not (file.filename or "").lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a PDF")
    if not detector:
        raise HTTPException(status_code=503, detail="Models not loaded")

    pdf_bytes = await file.read()

    email_data = PDFExtractor.extract_text_from_pdf(pdf_bytes)
    full_text = f"{email_data['subject']}\n\n{email_data['body']}"
    image = PDFExtractor.extract_images_from_pdf(pdf_bytes)

    results = {
        'email_data': email_data,
        'text_result': None,
        'image_result': None,
        'fusion_result': None,
    }
    if detector.text_model is not None:
        results['text_result'] = detector.predict_text(full_text)
    if detector.image_model is not None and image is not None:
        results['image_result'] = detector.predict_image(image)
    if detector.fusion_model is not None:
        results['fusion_result'] = detector.predict_fusion(full_text, image)

    # Priority fusion > text > image. predict_image may return {'error': ...},
    # which has no 'prediction' key and must not be chosen.
    final_result = next(
        (
            results[key]
            for key in ('fusion_result', 'text_result', 'image_result')
            if results[key] and 'prediction' in results[key]
        ),
        None,
    )
    if final_result is None:
        raise HTTPException(status_code=503, detail="No models available for prediction")
    results['final_prediction'] = final_result['prediction']
    results['final_confidence'] = final_result['confidence']
    return results
if __name__ == "__main__":
    import uvicorn

    # Port 7860 stays the default; a PORT environment variable overrides it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))