File size: 1,694 Bytes
90b561e
 
 
 
1026718
90b561e
 
1026718
 
 
 
 
 
 
90b561e
 
1026718
 
 
 
90b561e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import pickle
import torch
import gradio as gr
from transformers import BertForSequenceClassification
from huggingface_hub import hf_hub_download

# ---------------------------------------------------------------------------
# Model loading (runs once at import time)
# ---------------------------------------------------------------------------
# Both the pickled bundle and the BERT weights are fetched from this Hub repo.
_HUB_REPO_ID = "prafulgulani/email_classification_bert"

# Download the pickled bundle and get the local path of the downloaded file.
variables_path = hf_hub_download(repo_id=_HUB_REPO_ID, filename="variables.pkl")

# The bundle unpacks to four objects, in this order: the Naive Bayes model,
# the logistic regression model, the TF-IDF vectorizer and the tokenizer.
# NOTE(security): unpickling executes arbitrary code embedded in the file, so
# this is only safe as long as the Hub repository above is trusted.
with open(variables_path, "rb") as bundle_file:
    naive_model, logistic_model, tfidf_vectorizer, tokenizer = pickle.load(bundle_file)

# Fetch the fine-tuned BERT classifier from the same repo and switch it to
# evaluation mode for inference.
bert_model = BertForSequenceClassification.from_pretrained(_HUB_REPO_ID)
bert_model.eval()

def _spam_label(prediction):
    """Return "Spam" when *prediction* equals 1, otherwise "Ham"."""
    # NOTE(review): assumes every model encodes the spam class as 1 -- confirm
    # against the training code; any other value is reported as "Ham".
    return "Spam" if prediction == 1 else "Ham"


def predict_response(text, history):
    """Classify a chat message with all three models and report each verdict.

    Args:
        text: The message typed by the user.
        history: Earlier chat turns supplied by ``gr.ChatInterface``. Unused,
            but kept because the chat callback is called with two arguments.

    Returns:
        A newline-separated string with one ``"<model name>: Spam|Ham"`` line
        per model (Naive Bayes, Logistic Regression, BERT). When the message
        is empty or whitespace-only, a short prompt asking for input is
        returned instead and no model is run.
    """
    # Nothing to classify: a verdict on an empty message would be meaningless.
    if not text or not text.strip():
        return "Please enter an email or message to classify."

    # Traditional models share a single TF-IDF representation of the message.
    tfidf_input = tfidf_vectorizer.transform([text])
    naive_prediction = naive_model.predict(tfidf_input)[0]
    logistic_prediction = logistic_model.predict(tfidf_input)[0]

    # BERT model: make sure it is in evaluation mode, tokenize (truncating to
    # at most 512 tokens) and take the highest-scoring class index.
    bert_model.eval()
    encoded_input = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():  # inference only, no gradients needed
        output = bert_model(**encoded_input)
    bert_prediction = torch.argmax(output.logits, dim=1).item()

    verdicts = {
        "Naive Bayes": naive_prediction,
        "Logistic Regression": logistic_prediction,
        "BERT": bert_prediction,
    }
    return "\n".join(
        f"{model_name}: {_spam_label(prediction)}"
        for model_name, prediction in verdicts.items()
    )

# Three-line input box shown at the bottom of the chat window.
_message_box = gr.Textbox(
    lines=3,
    placeholder="Type an email or message here...",
)

# Chat UI: every submitted message is passed to predict_response and the
# returned text is shown as the reply.
chat = gr.ChatInterface(
    title="Spam Detector Chat",
    description="Send a message and see how different models classify it.",
    textbox=_message_box,
    fn=predict_response,
)

# share=True asks Gradio to also create a public share link for the app.
chat.launch(share=True)