# Hugging Face Spaces app — note: the Space's page header reported status "Runtime error".
import pickle

import torch
import gradio as gr
from transformers import BertForSequenceClassification
from huggingface_hub import hf_hub_download

# --- Model loading (runs once at import time) ---

# The classical models (Naive Bayes, Logistic Regression), the TF-IDF
# vectorizer, and the BERT tokenizer are bundled together in one pickle
# file hosted on the Hugging Face Hub.
variables_path = hf_hub_download(
    repo_id="prafulgulani/email_classification_bert",
    filename="variables.pkl",
)

# SECURITY NOTE: pickle.load executes arbitrary code from the loaded file.
# This is acceptable only because the repo above is trusted; never point
# this at an untrusted repo_id.
with open(variables_path, "rb") as f:
    naive_model, logistic_model, tfidf_vectorizer, tokenizer = pickle.load(f)

# Fine-tuned BERT spam classifier, loaded directly from the Hub.
bert_model = BertForSequenceClassification.from_pretrained(
    "prafulgulani/email_classification_bert"
)
bert_model.eval()  # inference mode: disables dropout etc.
def predict_response(text, history):
    """Classify a message as Spam/Ham with three models and report all verdicts.

    Parameters
    ----------
    text : str
        The message typed into the chat box.
    history : list
        Conversation history supplied by gr.ChatInterface; unused here.

    Returns
    -------
    str
        One line per model, e.g. "Naive Bayes: Spam".
    """
    # Classical models share a single TF-IDF feature representation.
    tfidf_input = tfidf_vectorizer.transform([text])
    naive_prediction = naive_model.predict(tfidf_input)[0]
    logistic_prediction = logistic_model.predict(tfidf_input)[0]

    # BERT: tokenize, forward pass without gradients, argmax over the two
    # logits. (bert_model.eval() was already called once at load time, so
    # the redundant per-call eval() was dropped.)
    encoded_input = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        output = bert_model(**encoded_input)
    bert_prediction = torch.argmax(output.logits, dim=1).item()

    # Label convention (established by the 1 -> 'Spam' mapping below):
    # class 1 is Spam, anything else is Ham.
    return "\n".join([
        f"Naive Bayes: {'Spam' if naive_prediction == 1 else 'Ham'}",
        f"Logistic Regression: {'Spam' if logistic_prediction == 1 else 'Ham'}",
        f"BERT: {'Spam' if bert_prediction == 1 else 'Ham'}",
    ])
# --- Gradio UI wiring ---
# Every submitted message is passed to predict_response and the combined
# verdict string is shown as the bot reply.
chat = gr.ChatInterface(
    fn=predict_response,
    title="Spam Detector Chat",
    description="Send a message and see how different models classify it.",
    textbox=gr.Textbox(placeholder="Type an email or message here...", lines=3),
)

# share=True creates a temporary public link when run locally; it is
# ignored on Hugging Face Spaces, which hosts the app itself.
chat.launch(share=True)