Spaces:
Paused
Paused
Search commited on
Commit ·
eb96110
1
Parent(s): 35e5b8a
fix: recognize uniform_small model type in run_experiment
Browse files- src/fog/train.py +1 -1
src/fog/train.py
CHANGED
|
@@ -122,7 +122,7 @@ def run_experiment(
|
|
| 122 |
eval_loader = DataLoader(eval_ds, batch_size=batch_size)
|
| 123 |
|
| 124 |
# Model
|
| 125 |
-
if model_type
|
| 126 |
model = BaselineTransformer(cfg).to(device)
|
| 127 |
elif model_type == "motif":
|
| 128 |
model = MotifTransformer(cfg).to(device)
|
|
|
|
| 122 |
eval_loader = DataLoader(eval_ds, batch_size=batch_size)
|
| 123 |
|
| 124 |
# Model
|
| 125 |
+
if model_type in ("baseline", "uniform_small"):
|
| 126 |
model = BaselineTransformer(cfg).to(device)
|
| 127 |
elif model_type == "motif":
|
| 128 |
model = MotifTransformer(cfg).to(device)
|