yonghan93 commited on
Commit
f8ec48b
·
verified ·
1 Parent(s): 7fefdd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -6,10 +6,21 @@ import json
6
  import pickle
7
  from pytorch_tabular import TabularModel
8
  from sklearn.preprocessing import LabelEncoder
 
 
9
 
10
  # Load model
11
  model = TabularModel.load_model("FTTransformerModel")
12
 
 
 
 
 
 
 
 
 
 
13
  #Encoder
14
  dm = model.datamodule
15
  if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
 
6
  import pickle
7
  from pytorch_tabular import TabularModel
8
  from sklearn.preprocessing import LabelEncoder
9
+ from omegaconf import OmegaConf, DictConfig
10
+
11
 
12
  # Load model
13
  model = TabularModel.load_model("FTTransformerModel")
14
 
15
+ # Merge default dataloader_kwargs into config
16
+ if isinstance(model.config, DictConfig):
17
+ model.config = OmegaConf.merge(
18
+ model.config,
19
+ {"dataloader_kwargs": {"batch_size": 64, "num_workers": 0}}
20
+ )
21
+ else:
22
+ model.config.dataloader_kwargs = {"batch_size": 64, "num_workers": 0}
23
+
24
  #Encoder
25
  dm = model.datamodule
26
  if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):