Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,15 +26,16 @@ import pickle
|
|
| 26 |
from pytorch_tabular import TabularModel
|
| 27 |
from sklearn.preprocessing import LabelEncoder
|
| 28 |
from omegaconf import OmegaConf, DictConfig
|
|
|
|
| 29 |
|
| 30 |
# Load model
|
| 31 |
model = TabularModel.load_model("FTTransformerModel")
|
| 32 |
|
| 33 |
-
model.datamodule._inferred_config = model._inferred_config
|
| 34 |
-
|
| 35 |
dm = model.datamodule
|
| 36 |
if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
|
| 37 |
dm.label_encoder = [dm.label_encoder]
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Load threshold
|
| 40 |
with open("FTTransformerModel/threshold.json", "r") as f:
|
|
|
|
| 26 |
from pytorch_tabular import TabularModel
|
| 27 |
from sklearn.preprocessing import LabelEncoder
|
| 28 |
from omegaconf import OmegaConf, DictConfig
|
| 29 |
+
from types import SimpleNamespace
|
| 30 |
|
| 31 |
# Load model
|
| 32 |
model = TabularModel.load_model("FTTransformerModel")
|
| 33 |
|
|
|
|
|
|
|
| 34 |
dm = model.datamodule
|
| 35 |
if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
|
| 36 |
dm.label_encoder = [dm.label_encoder]
|
| 37 |
+
|
| 38 |
+
dm._inferred_config = SimpleNamespace(output_cardinality=[2])
|
| 39 |
|
| 40 |
# Load threshold
|
| 41 |
with open("FTTransformerModel/threshold.json", "r") as f:
|