Update modeling.py
Browse files- modeling.py +1 -1
modeling.py
CHANGED
|
@@ -122,7 +122,7 @@ class MARModel(PreTrainedModel):
|
|
| 122 |
num_sampling_steps_diffloss = kwargs.get('num_sampling_steps', 100)
|
| 123 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 124 |
model_type = "mar_base"
|
| 125 |
-
model_architecture =
|
| 126 |
buffer_size=buffer_size,
|
| 127 |
diffloss_d=diffloss_d,
|
| 128 |
diffloss_w=diffloss_w,
|
|
|
|
| 122 |
num_sampling_steps_diffloss = kwargs.get('num_sampling_steps', 100)
|
| 123 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 124 |
model_type = "mar_base"
|
| 125 |
+
model_architecture = mar.__dict__[model_type](
|
| 126 |
buffer_size=buffer_size,
|
| 127 |
diffloss_d=diffloss_d,
|
| 128 |
diffloss_w=diffloss_w,
|