prafulgulani555
final changes
1026718
import pickle
import torch
import gradio as gr
from transformers import BertForSequenceClassification
from huggingface_hub import hf_hub_download
# --- Model loading ----------------------------------------------------------
# Both the classical-model pickle and the BERT weights come from this single
# Hugging Face Hub repository.
_HUB_REPO_ID = "prafulgulani/email_classification_bert"

# Classical models plus their preprocessing objects, stored together as one
# pickled 4-tuple: (naive_model, logistic_model, tfidf_vectorizer, tokenizer).
variables_path = hf_hub_download(repo_id=_HUB_REPO_ID, filename="variables.pkl")
with open(variables_path, "rb") as f:
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; this is acceptable only as long as the Hub repo is trusted.
    # Consider pinning `revision=` so an upstream change is never loaded
    # silently.
    naive_model, logistic_model, tfidf_vectorizer, tokenizer = pickle.load(f)

# Fine-tuned BERT sequence classifier, loaded directly from the Hub and
# switched to evaluation mode once, up front.
bert_model = BertForSequenceClassification.from_pretrained(_HUB_REPO_ID)
bert_model.eval()
def _spam_or_ham(prediction) -> str:
    """Return the display label for a class prediction: 1 -> "Spam", else "Ham"."""
    return "Spam" if prediction == 1 else "Ham"


def predict_response(text, history):
    """Classify one chat message with every model and format the verdicts.

    Used as the ``fn`` of ``gr.ChatInterface``, which calls it with the newly
    submitted message and the conversation history.

    Args:
        text: The message/email body typed by the user.
        history: Prior chat turns supplied by Gradio. Unused, because each
            message is classified independently.

    Returns:
        Three newline-separated lines ("Naive Bayes: ...",
        "Logistic Regression: ...", "BERT: ...") each ending in "Spam" or
        "Ham". If ``text`` is empty, ``None`` or whitespace-only, returns a
        short prompt asking for input instead.
    """
    # A blank message gives the models nothing to classify. Prompt the user
    # rather than reporting confident-looking verdicts on empty input.
    if not text or not text.strip():
        return "Please enter a message to classify."

    # Traditional models: both consume the same single-row TF-IDF matrix.
    tfidf_input = tfidf_vectorizer.transform([text])
    naive_prediction = naive_model.predict(tfidf_input)[0]
    logistic_prediction = logistic_model.predict(tfidf_input)[0]

    # BERT: the model is put in eval mode once when it is loaded, so it is
    # not re-done per request. no_grad skips autograd bookkeeping because
    # only the forward pass is needed.
    encoded_input = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        output = bert_model(**encoded_input)
    # Take the highest-scoring class per row; the batch holds one row.
    bert_prediction = torch.argmax(output.logits, dim=-1).item()

    return "\n".join(
        [
            f"Naive Bayes: {_spam_or_ham(naive_prediction)}",
            f"Logistic Regression: {_spam_or_ham(logistic_prediction)}",
            f"BERT: {_spam_or_ham(bert_prediction)}",
        ]
    )
# --- User interface ---------------------------------------------------------
# A multi-line input box, so whole emails can be pasted in comfortably.
input_box = gr.Textbox(
    placeholder="Type an email or message here...",
    lines=3,
)

# Chat-style front end: every submitted message is routed to predict_response,
# and its per-model verdict text is shown as the reply.
chat = gr.ChatInterface(
    fn=predict_response,
    textbox=input_box,
    title="Spam Detector Chat",
    description="Send a message and see how different models classify it.",
)

# share=True asks Gradio for a publicly shareable link.
# NOTE(review): confirm this is still wanted when the app is hosted (e.g. on
# HF Spaces).
chat.launch(share=True)