chenjgtea commited on
Commit ·
f4f441a
1
Parent(s): 2aa8287
torch 更新
Browse files- tool/ctx.py +4 -0
tool/ctx.py
CHANGED
|
@@ -1,13 +1,17 @@
|
|
| 1 |
import torch
|
|
|
|
|
|
|
| 2 |
|
| 3 |
class TorchSeedContext:
|
| 4 |
def __init__(self, seed):
|
| 5 |
self.seed = seed
|
| 6 |
self.state = None
|
| 7 |
|
|
|
|
| 8 |
def __enter__(self):
|
| 9 |
self.state = torch.random.get_rng_state()
|
| 10 |
torch.manual_seed(self.seed)
|
| 11 |
|
|
|
|
| 12 |
def __exit__(self, type, value, traceback):
|
| 13 |
torch.random.set_rng_state(self.state)
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import spaces
|
| 3 |
+
|
| 4 |
|
| 5 |
class TorchSeedContext:
|
| 6 |
def __init__(self, seed):
|
| 7 |
self.seed = seed
|
| 8 |
self.state = None
|
| 9 |
|
| 10 |
+
@spaces.GPU
|
| 11 |
def __enter__(self):
|
| 12 |
self.state = torch.random.get_rng_state()
|
| 13 |
torch.manual_seed(self.seed)
|
| 14 |
|
| 15 |
+
@spaces.GPU
|
| 16 |
def __exit__(self, type, value, traceback):
|
| 17 |
torch.random.set_rng_state(self.state)
|