Pant0x commited on
Commit
6a450c0
·
verified ·
1 Parent(s): 4e9eeac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -3,29 +3,26 @@ import torch
3
  import torch.nn as nn
4
  import joblib
5
 
6
- # ---------- Model Definition (Fixed to match .pth file) ----------
7
  class SpamClassifier(nn.Module):
8
  def __init__(self, input_dim):
9
  super(SpamClassifier, self).__init__()
10
- # Based on the error log, the saved model has:
11
- # Layer 0: Linear(1000 -> 64)
12
- # Layer 2: Linear(64 -> 32)
13
- # Layer 4: Linear(32 -> 2)
14
  self.model = nn.Sequential(
15
  nn.Linear(input_dim, 64),
16
  nn.ReLU(),
17
  nn.Linear(64, 32),
18
  nn.ReLU(),
19
- nn.Linear(32, 2),
20
- nn.Softmax(dim=1)
 
 
21
  )
22
 
23
  def forward(self, x):
24
  return self.model(x)
25
 
26
  # ---------- Load Vectorizer ----------
27
- # Note: Ensure the environment uses scikit-learn version close to 1.2.2
28
- # if possible to avoid the warnings in your logs, though it often works anyway.
29
  vectorizer = joblib.load("model/vectorizer.pkl")
30
  input_dim = len(vectorizer.get_feature_names_out())
31
 
@@ -38,10 +35,13 @@ model.eval()
38
  def predict_email(text):
39
  X = vectorizer.transform([text]).toarray()
40
  X_tensor = torch.tensor(X, dtype=torch.float32)
 
41
  with torch.no_grad():
42
- probs = model(X_tensor).numpy()[0]
43
- labels = ["Ham", "Spam"]
44
- return {labels[i]: float(probs[i]) for i in range(2)}
 
 
45
 
46
  # ---------- Gradio Interface ----------
47
  iface = gr.Interface(
 
3
  import torch.nn as nn
4
  import joblib
5
 
6
+ # ---------- Model Definition (Fixed Architecture) ----------
7
  class SpamClassifier(nn.Module):
8
  def __init__(self, input_dim):
9
  super(SpamClassifier, self).__init__()
 
 
 
 
10
  self.model = nn.Sequential(
11
  nn.Linear(input_dim, 64),
12
  nn.ReLU(),
13
  nn.Linear(64, 32),
14
  nn.ReLU(),
15
+ # CHANGED: Output layer is 1, not 2
16
+ nn.Linear(32, 1),
17
+ # CHANGED: Use Sigmoid for binary output (0 to 1) instead of Softmax
18
+ nn.Sigmoid()
19
  )
20
 
21
  def forward(self, x):
22
  return self.model(x)
23
 
24
  # ---------- Load Vectorizer ----------
25
+ # Ensure you have scikit-learn==1.2.2 in requirements.txt if you get warnings
 
26
  vectorizer = joblib.load("model/vectorizer.pkl")
27
  input_dim = len(vectorizer.get_feature_names_out())
28
 
 
35
  def predict_email(text):
36
  X = vectorizer.transform([text]).toarray()
37
  X_tensor = torch.tensor(X, dtype=torch.float32)
38
+
39
  with torch.no_grad():
40
+ # Get the single probability value (0 = Ham, 1 = Spam)
41
+ prob_spam = model(X_tensor).item()
42
+
43
+ # Manually calculate Ham probability as the complement of Spam
44
+ return {"Ham": 1 - prob_spam, "Spam": prob_spam}
45
 
46
  # ---------- Gradio Interface ----------
47
  iface = gr.Interface(