"""Gradio chat demo that classifies messages as Spam/Ham with three models.

Loads a Naive Bayes model, a Logistic Regression model, a TF-IDF vectorizer,
and a BERT tokenizer from a pickled bundle on the Hugging Face Hub, plus a
fine-tuned BERT sequence classifier, then serves a ChatInterface that shows
each model's prediction for the submitted text.
"""

import pickle

import torch
import gradio as gr
from transformers import BertForSequenceClassification
from huggingface_hub import hf_hub_download

# --- Load traditional models from the Hugging Face Hub ---
variables_path = hf_hub_download(
    repo_id="prafulgulani/email_classification_bert",
    filename="variables.pkl",
)
# SECURITY NOTE: pickle.load executes arbitrary code from the artifact.
# Only acceptable here because the repo is trusted/first-party; do not
# point this at untrusted repos.
with open(variables_path, "rb") as f:
    naive_model, logistic_model, tfidf_vectorizer, tokenizer = pickle.load(f)

# --- Load the fine-tuned BERT classifier directly from the Hub ---
bert_model = BertForSequenceClassification.from_pretrained(
    "prafulgulani/email_classification_bert"
)
bert_model.eval()  # inference mode once, at load time


def predict_response(text, history):
    """Classify *text* with all three models and return a multi-line summary.

    Args:
        text: The message to classify.
        history: Chat history supplied by gr.ChatInterface (unused).

    Returns:
        One line per model, e.g. "Naive Bayes: Spam", joined by newlines.
        Label 1 is treated as Spam, anything else as Ham.
    """
    # Traditional models share the same TF-IDF feature representation.
    tfidf_input = tfidf_vectorizer.transform([text])
    naive_prediction = naive_model.predict(tfidf_input)[0]
    logistic_prediction = logistic_model.predict(tfidf_input)[0]

    # BERT: truncate to the model's 512-token limit; no gradients needed
    # for inference. (eval() already set at load time — no need per call.)
    encoded_input = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        output = bert_model(**encoded_input)
    # Use the documented `dim=` parameter rather than the NumPy `axis` alias.
    bert_prediction = torch.argmax(output.logits, dim=1).item()

    result = []
    result.append(f"Naive Bayes: {'Spam' if naive_prediction == 1 else 'Ham'}")
    result.append(f"Logistic Regression: {'Spam' if logistic_prediction == 1 else 'Ham'}")
    result.append(f"BERT: {'Spam' if bert_prediction == 1 else 'Ham'}")
    return "\n".join(result)


chat = gr.ChatInterface(
    fn=predict_response,
    title="Spam Detector Chat",
    description="Send a message and see how different models classify it.",
    textbox=gr.Textbox(placeholder="Type an email or message here...", lines=3),
)

# Launch only when executed as a script, not on import.
if __name__ == "__main__":
    chat.launch(share=True)