Search commited on
Commit
eb96110
·
1 Parent(s): 35e5b8a

fix: recognize uniform_small model type in run_experiment

Browse files
Files changed (1) hide show
  1. src/fog/train.py +1 -1
src/fog/train.py CHANGED
@@ -122,7 +122,7 @@ def run_experiment(
122
  eval_loader = DataLoader(eval_ds, batch_size=batch_size)
123
 
124
  # Model
125
- if model_type == "baseline":
126
  model = BaselineTransformer(cfg).to(device)
127
  elif model_type == "motif":
128
  model = MotifTransformer(cfg).to(device)
 
122
  eval_loader = DataLoader(eval_ds, batch_size=batch_size)
123
 
124
  # Model
125
+ if model_type in ("baseline", "uniform_small"):
126
  model = BaselineTransformer(cfg).to(device)
127
  elif model_type == "motif":
128
  model = MotifTransformer(cfg).to(device)