import torch from safetensors.torch import load_file import json class PasswordTesterDataset: def __init__(self, safetensors_path, tokenizer_path): self.tensors = load_file(safetensors_path) with open(tokenizer_path) as f: self.tokenizer = json.load(f) # flatten tensors into one long tensor self.data = torch.cat([t for t in self.tensors.values()], dim=0) def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx]