"""Gradio demo that classifies an email as Spam or Safe with a PyTorch MLP.

The fitted vectorizer and the model weights are loaded once at import time
from the ``model/`` directory; ``predict_email`` is served through a Gradio
text-in / label-out interface.
"""
import gradio as gr
import joblib
import torch
import torch.nn as nn

# ---------- Configuration ----------
VECTORIZER_PATH = "model/vectorizer.pkl"
MODEL_PATH = "model/email_spam_classifier.pth"


# ---------- Model Definition ----------
class SpamClassifier(nn.Module):
    """Feed-forward binary classifier: input_dim -> 64 -> 32 -> 1 -> sigmoid.

    The layer layout must match the one used when the weights at
    ``MODEL_PATH`` were saved, otherwise ``load_state_dict`` fails.

    Args:
        input_dim: Number of features produced by the vectorizer.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Attribute name "model" is part of the saved state_dict keys
        # (e.g. "model.0.weight"); do not rename it.
        self.model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            # A single output unit followed by a sigmoid gives P(spam) in [0, 1].
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return spam probabilities; a (batch, input_dim) input yields (batch, 1)."""
        return self.model(x)


# ---------- Load Vectorizer ----------
# Ensure you have scikit-learn==1.2.2 in requirements.txt if you get
# unpickling version warnings.
# NOTE(security): joblib.load unpickles arbitrary objects; only load
# vectorizer files from a trusted source.
vectorizer = joblib.load(VECTORIZER_PATH)
input_dim = len(vectorizer.get_feature_names_out())

# ---------- Load Model ----------
model = SpamClassifier(input_dim)
# weights_only=True restricts unpickling to tensors and primitive containers,
# so a tampered checkpoint cannot execute code. This requires torch >= 1.13.
_state_dict = torch.load(
    MODEL_PATH,
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(_state_dict)
model.eval()


# ---------- Prediction Function ----------
def predict_email(text):
    """Classify a single email body.

    Args:
        text: Raw email text. ``None`` is treated as an empty string.

    Returns:
        Dict mapping ``"Safe"`` and ``"Spam"`` to probabilities that sum to 1,
        in the format ``gr.Label`` accepts.
    """
    # The vectorizer cannot transform None, so normalise a missing input to "".
    features = vectorizer.transform([text or ""]).toarray()
    X_tensor = torch.tensor(features, dtype=torch.float32)
    with torch.no_grad():
        # One sigmoid output per row: the probability that the email is spam.
        prob_spam = model(X_tensor).item()
    # The "Safe" probability is the complement of the spam probability.
    return {"Safe": 1 - prob_spam, "Spam": prob_spam}


# ---------- Gradio Interface ----------
iface = gr.Interface(
    fn=predict_email,
    inputs=gr.Textbox(lines=5, placeholder="Paste your email here..."),
    outputs=gr.Label(num_top_classes=2),
    title="Email Spam Classifier",
    description="Classify emails as Spam or Safe using a PyTorch model.",
)

if __name__ == "__main__":
    iface.launch()