File size: 2,151 Bytes
2afe0ff
 
 
9cd03b7
2afe0ff
4878e99
2afe0ff
 
 
 
 
 
 
 
 
 
0b77b61
2afe0ff
 
 
 
 
 
 
 
 
 
 
 
0b77b61
2afe0ff
0b77b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9d9bc8
2afe0ff
 
 
 
 
c9d9bc8
2afe0ff
 
c9d9bc8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
# Directory the tokenizer and the classification model are both loaded from.
MODEL_PATH = "./model"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

# Pick CUDA when torch reports it as available, otherwise fall back to the
# CPU, and move the model there so it matches the tensors sent at inference.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Function to predict spam
def predict_spam(text):
    """
    Classify a piece of email text as phishing or not.

    The text is cleaned with ``preprocess``, tokenized (truncated to at most
    512 tokens), moved to the module-level ``device`` and run through the
    module-level ``model`` without gradient tracking.

    :param text: Raw email content.
    :return: "Phishing Mail" when the highest-scoring class index is 1,
             otherwise "Not Phishing Mail".
    """
    cleaned = preprocess(text)

    # Encode the cleaned text as PyTorch tensors.
    encoded = tokenizer(cleaned, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Every tensor must live on the same device as the model.
    batch = {}
    for name, tensor in encoded.items():
        batch[name] = tensor.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**batch).logits

    # Index of the largest logit along the class dimension.
    label_id = torch.argmax(logits, dim=1).item()

    if label_id == 1:
        return "Phishing Mail"
    return "Not Phishing Mail"

def preprocess(text: str) -> str:
    """
    Normalise raw text before it is handed to the tokenizer.

    Steps, in order:
    - Trim the text and collapse every run of whitespace (spaces, tabs,
      newlines) into a single space. Original line and paragraph breaks are
      therefore not kept.
    - Drop the space in front of "." and ",".
    - Split on ``r/<name>:`` headers. Each header is emitted as
      ``**r/<name>:**`` surrounded by blank lines, and each stretch of text
      between headers is emitted followed by a blank line.

    :param text: Input raw text with excessive spaces
    :return: Cleaned and formatted text, without leading/trailing whitespace
    """
    # A single pass is enough: any blank-line run is itself a whitespace run,
    # so normalising blank lines first would not change the result.
    text = re.sub(r'\s+', ' ', text.strip())
    text = text.replace(" .", ".").replace(" ,", ",")

    # Because the pattern has a capturing group, re.split returns the matched
    # headers at odd indices and the surrounding text at even indices.
    sections = re.split(r'(r/\w+:\s*)', text)

    parts = []
    for index, section in enumerate(sections):
        if index % 2:
            # A real ``r/<name>:`` header matched by the pattern.
            parts.append(f"\n\n**{section.strip()}**\n\n")
        else:
            # Ordinary text. Deciding by position rather than by a "r/"
            # prefix keeps text such as "r/foo bar" (no colon, so not a
            # header) from being wrapped in bold markers.
            parts.append(section.strip() + "\n\n")
    return "".join(parts).strip()
    
# Gradio Interface for Web App
# The input and output widgets are created first, then wired to predict_spam.
email_box = gr.Textbox(label="Enter Email Content")
result_label = gr.Label(label="Spam Detection Result")

iface = gr.Interface(
    fn=predict_spam,
    inputs=email_box,
    outputs=result_label,
    title="📧 Spam Email Detector",
    description="Enter an email body to detect whether it's spam or not.",
)

# Serve on every network interface at port 7860; share=True additionally asks
# Gradio for a public share link.
launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
iface.launch(**launch_options)