yonghan93 commited on
Commit
8e586ae
·
verified ·
1 Parent(s): 6519728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -5,11 +5,16 @@ import torch
5
  import json
6
  import pickle
7
  from pytorch_tabular import TabularModel
8
-
9
 
10
  # Load model
11
  model = TabularModel.load_model("FTTransformerModel")
12
 
 
 
 
 
 
13
  # Load threshold
14
  with open("FTTransformerModel/threshold.json", "r") as f:
15
  threshold = json.load(f)["threshold"]
 
5
  import json
6
  import pickle
7
  from pytorch_tabular import TabularModel
8
+ from sklearn.preprocessing import LabelEncoder
9
 
10
  # Load model
11
  model = TabularModel.load_model("FTTransformerModel")
12
 
13
+ #Encoder
14
+ dm = model.datamodule
15
+ if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
16
+ dm.label_encoder = [dm.label_encoder]
17
+
18
  # Load threshold
19
  with open("FTTransformerModel/threshold.json", "r") as f:
20
  threshold = json.load(f)["threshold"]