# Source: depthsplat/src/misc/step_tracker.py
# Uploaded by Yeqing0814 via huggingface_hub (revision a6dd040, verified).
from multiprocessing import RLock
import torch
from jaxtyping import Int64
from torch import Tensor
from torch.multiprocessing import Manager
class StepTracker:
    """Hold a single global step counter that several processes can share.

    The counter is a 0-dim int64 tensor whose storage is moved into shared
    memory. Every read and write goes through a re-entrant lock created by a
    ``Manager``, which serializes access to it.
    (Presumably the tracker is handed to worker processes so they can see the
    current training step — confirm against callers.)
    """

    # Re-entrant lock proxy obtained from a Manager; guards every access to `step`.
    lock: RLock
    # Scalar (0-dim) int64 tensor holding the current step, in shared memory.
    step: Int64[Tensor, ""]

    def __init__(self):
        """Start at step 0 and set up the shared tensor and its lock."""
        # `Manager()` starts a manager server process; the lock proxy it
        # hands back is what every accessor synchronizes on.
        manager = Manager()
        self.lock = manager.RLock()

        # Build a scalar int64 zero, then move its storage into shared memory.
        # `share_memory_` works in place and returns the same tensor.
        shared_step = torch.tensor(0, dtype=torch.int64)
        shared_step.share_memory_()
        self.step = shared_step

    def set_step(self, step: int) -> None:
        """Overwrite the shared counter with ``step`` while holding the lock."""
        with self.lock:
            self.step.fill_(step)

    def get_step(self) -> int:
        """Return the shared counter as a Python ``int``, read under the lock."""
        with self.lock:
            current = self.step.item()
        return current