Commit
·
8e2494d
1
Parent(s):
5dc2a44
fix
Browse files
__pycache__/modeling_ministu.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/modeling_ministu.cpython-312.pyc and b/__pycache__/modeling_ministu.cpython-312.pyc differ
|
|
|
configuration_ministu.py
CHANGED
|
@@ -53,4 +53,4 @@ class MiniSTUConfig(PretrainedConfig):
|
|
| 53 |
self.theta = theta
|
| 54 |
self.use_alibi = use_alibi
|
| 55 |
self.torch_dtype = torch_dtype
|
| 56 |
-
self.device = device
|
|
|
|
| 53 |
self.theta = theta
|
| 54 |
self.use_alibi = use_alibi
|
| 55 |
self.torch_dtype = torch_dtype
|
| 56 |
+
self.device = self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') # Store as string
|
modeling_ministu.py
CHANGED
|
@@ -434,6 +434,7 @@ class MiniSTU(PreTrainedModel):
|
|
| 434 |
K=config.num_eigh,
|
| 435 |
use_hankel_L=config.use_hankel_L,
|
| 436 |
device=config.device,
|
|
|
|
| 437 |
)
|
| 438 |
|
| 439 |
self.num_layers = config.num_layers
|
|
|
|
| 434 |
K=config.num_eigh,
|
| 435 |
use_hankel_L=config.use_hankel_L,
|
| 436 |
device=config.device,
|
| 437 |
+
dtype=config.torch_dtype,
|
| 438 |
)
|
| 439 |
|
| 440 |
self.num_layers = config.num_layers
|