tabito12345678910 commited on
Commit
10318b6
·
2 Parent(s): 11e93ed 0fa4db6

Merge remote changes and resolve conflicts

Browse files
Files changed (1) hide show
  1. inference_gohan_cid.py +2 -2
inference_gohan_cid.py CHANGED
@@ -58,10 +58,10 @@ class GohanCIDInferenceEngine:
58
  n_num_features=5, # Updated: 5 numerical features (age ranges are now categorical)
59
  cat_cardinalities=self.cat_cardinalities,
60
  d_out=len(self.all_cids),
61
- d_token=768, # Use the actual saved model's d_token
62
  n_blocks=8,
63
  attention_dropout=0.15,
64
- ffn_d_hidden=768, # Use the actual saved model's ffn_d_hidden
65
  ffn_dropout=0.15,
66
  residual_dropout=0.10
67
  )
 
58
  n_num_features=5, # Updated: 5 numerical features (age ranges are now categorical)
59
  cat_cardinalities=self.cat_cardinalities,
60
  d_out=len(self.all_cids),
61
+ d_token=1024, # Use the actual saved model's d_token
62
  n_blocks=8,
63
  attention_dropout=0.15,
64
+ ffn_d_hidden=1024, # Use the actual saved model's ffn_d_hidden
65
  ffn_dropout=0.15,
66
  residual_dropout=0.10
67
  )