tabito123 commited on
Commit
d401fb9
·
verified ·
1 Parent(s): 8610011

Update inference_gohan_cid.py

Browse files
Files changed (1) hide show
  1. inference_gohan_cid.py +2 -2
inference_gohan_cid.py CHANGED
@@ -58,10 +58,10 @@ class GohanCIDInferenceEngine:
58
  n_num_features=8,
59
  cat_cardinalities=self.cat_cardinalities,
60
  d_out=len(self.all_cids),
61
- d_token=768, # Use the actual saved model's d_token
62
  n_blocks=8,
63
  attention_dropout=0.15,
64
- ffn_d_hidden=768, # Use the actual saved model's ffn_d_hidden
65
  ffn_dropout=0.15,
66
  residual_dropout=0.10
67
  )
 
58
  n_num_features=8,
59
  cat_cardinalities=self.cat_cardinalities,
60
  d_out=len(self.all_cids),
61
+ d_token=1024, # Use the actual saved model's d_token
62
  n_blocks=8,
63
  attention_dropout=0.15,
64
+ ffn_d_hidden=1024, # Use the actual saved model's ffn_d_hidden
65
  ffn_dropout=0.15,
66
  residual_dropout=0.10
67
  )