Pant0x commited on
Commit
52e9136
·
verified ·
1 Parent(s): 04cb993

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -3,14 +3,20 @@ import torch
3
  import torch.nn as nn
4
  import joblib
5
 
6
- # ---------- Model Definition (matches saved model) ----------
7
  class SpamClassifier(nn.Module):
8
  def __init__(self, input_dim):
9
  super(SpamClassifier, self).__init__()
 
 
 
 
10
  self.model = nn.Sequential(
11
- nn.Linear(input_dim, 128),
12
  nn.ReLU(),
13
- nn.Linear(128, 2),
 
 
14
  nn.Softmax(dim=1)
15
  )
16
 
@@ -18,6 +24,8 @@ class SpamClassifier(nn.Module):
18
  return self.model(x)
19
 
20
  # ---------- Load Vectorizer ----------
 
 
21
  vectorizer = joblib.load("model/vectorizer.pkl")
22
  input_dim = len(vectorizer.get_feature_names_out())
23
 
@@ -45,4 +53,4 @@ iface = gr.Interface(
45
  )
46
 
47
  if __name__ == "__main__":
48
- iface.launch()
 
3
  import torch.nn as nn
4
  import joblib
5
 
6
+ # ---------- Model Definition (Fixed to match .pth file) ----------
7
  class SpamClassifier(nn.Module):
8
  def __init__(self, input_dim):
9
  super(SpamClassifier, self).__init__()
10
+ # Based on the error log, the saved model has:
11
+ # Layer 0: Linear(1000 -> 64)
12
+ # Layer 2: Linear(64 -> 32)
13
+ # Layer 4: Linear(32 -> 2)
14
  self.model = nn.Sequential(
15
+ nn.Linear(input_dim, 64),
16
  nn.ReLU(),
17
+ nn.Linear(64, 32),
18
+ nn.ReLU(),
19
+ nn.Linear(32, 2),
20
  nn.Softmax(dim=1)
21
  )
22
 
 
24
  return self.model(x)
25
 
26
  # ---------- Load Vectorizer ----------
27
+ # Note: Ensure the environment uses scikit-learn version close to 1.2.2
28
+ # if possible to avoid the warnings in your logs, though it often works anyway.
29
  vectorizer = joblib.load("model/vectorizer.pkl")
30
  input_dim = len(vectorizer.get_feature_names_out())
31
 
 
53
  )
54
 
55
  if __name__ == "__main__":
56
+ iface.launch()