Pant0x commited on
Commit
92eb32f
·
verified ·
1 Parent(s): 92ba521

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -52
app.py CHANGED
@@ -1,52 +1,51 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- import pickle
5
-
6
- # ---------- Model Definition ----------
7
- class SpamClassifier(nn.Module):
8
- def __init__(self, input_dim):
9
- super(SpamClassifier, self).__init__()
10
- self.fc1 = nn.Linear(input_dim, 128)
11
- self.relu = nn.ReLU()
12
- self.fc2 = nn.Linear(128, 2)
13
- self.softmax = nn.Softmax(dim=1)
14
-
15
- def forward(self, x):
16
- x = self.fc1(x)
17
- x = self.relu(x)
18
- x = self.fc2(x)
19
- x = self.softmax(x)
20
- return x
21
-
22
- # ---------- Load Vectorizer ----------
23
- with open("model/vectorizer.pkl", "rb") as f:
24
- vectorizer = pickle.load(f)
25
-
26
- input_dim = len(vectorizer.get_feature_names_out())
27
-
28
- # ---------- Load Model ----------
29
- model = SpamClassifier(input_dim)
30
- model.load_state_dict(torch.load("model/email_spam_classifier.pth", map_location=torch.device("cpu")))
31
- model.eval()
32
-
33
- # ---------- Prediction Function ----------
34
- def predict_email(text):
35
- X = vectorizer.transform([text]).toarray()
36
- X_tensor = torch.tensor(X, dtype=torch.float32)
37
- with torch.no_grad():
38
- probs = model(X_tensor).numpy()[0]
39
- labels = ["Ham", "Spam"]
40
- return {labels[i]: float(probs[i]) for i in range(2)}
41
-
42
- # ---------- Gradio Interface ----------
43
- iface = gr.Interface(
44
- fn=predict_email,
45
- inputs=gr.Textbox(lines=5, placeholder="Paste your email here..."),
46
- outputs=gr.Label(num_top_classes=2),
47
- title="Email Spam Classifier",
48
- description="Classify emails as Spam or Ham using a PyTorch model."
49
- )
50
-
51
- if __name__ == "__main__":
52
- iface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import joblib
5
+
6
+ # ---------- Model Definition ----------
7
+ class SpamClassifier(nn.Module):
8
+ def __init__(self, input_dim):
9
+ super(SpamClassifier, self).__init__()
10
+ self.fc1 = nn.Linear(input_dim, 128)
11
+ self.relu = nn.ReLU()
12
+ self.fc2 = nn.Linear(128, 2)
13
+ self.softmax = nn.Softmax(dim=1)
14
+
15
+ def forward(self, x):
16
+ x = self.fc1(x)
17
+ x = self.relu(x)
18
+ x = self.fc2(x)
19
+ x = self.softmax(x)
20
+ return x
21
+
22
+ # ---------- Load Vectorizer ----------
23
+ vectorizer = joblib.load("model/vectorizer.pkl") # Use joblib for TF-IDF
24
+
25
+ input_dim = len(vectorizer.get_feature_names_out())
26
+
27
+ # ---------- Load Model ----------
28
+ model = SpamClassifier(input_dim)
29
+ model.load_state_dict(torch.load("model/email_spam_classifier.pth", map_location=torch.device("cpu")))
30
+ model.eval()
31
+
32
+ # ---------- Prediction Function ----------
33
+ def predict_email(text):
34
+ X = vectorizer.transform([text]).toarray()
35
+ X_tensor = torch.tensor(X, dtype=torch.float32)
36
+ with torch.no_grad():
37
+ probs = model(X_tensor).numpy()[0]
38
+ labels = ["Ham", "Spam"]
39
+ return {labels[i]: float(probs[i]) for i in range(2)}
40
+
41
+ # ---------- Gradio Interface ----------
42
+ iface = gr.Interface(
43
+ fn=predict_email,
44
+ inputs=gr.Textbox(lines=5, placeholder="Paste your email here..."),
45
+ outputs=gr.Label(num_top_classes=2),
46
+ title="Email Spam Classifier",
47
+ description="Classify emails as Spam or Ham using a PyTorch model."
48
+ )
49
+
50
+ if __name__ == "__main__":
51
+ iface.launch()