yonghan93 commited on
Commit
1d5d00a
·
verified ·
1 Parent(s): ca5a46a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -26,15 +26,16 @@ import pickle
26
  from pytorch_tabular import TabularModel
27
  from sklearn.preprocessing import LabelEncoder
28
  from omegaconf import OmegaConf, DictConfig
 
29
 
30
  # Load model
31
  model = TabularModel.load_model("FTTransformerModel")
32
 
33
- model.datamodule._inferred_config = model._inferred_config
34
-
35
  dm = model.datamodule
36
  if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
37
  dm.label_encoder = [dm.label_encoder]
 
 
38
 
39
  # Load threshold
40
  with open("FTTransformerModel/threshold.json", "r") as f:
 
26
  from pytorch_tabular import TabularModel
27
  from sklearn.preprocessing import LabelEncoder
28
  from omegaconf import OmegaConf, DictConfig
29
+ from types import SimpleNamespace
30
 
31
  # Load model
32
  model = TabularModel.load_model("FTTransformerModel")
33
 
 
 
34
  dm = model.datamodule
35
  if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
36
  dm.label_encoder = [dm.label_encoder]
37
+
38
+ dm._inferred_config = SimpleNamespace(output_cardinality=[2])
39
 
40
  # Load threshold
41
  with open("FTTransformerModel/threshold.json", "r") as f: