| import torch | |
| from safetensors.torch import load_file | |
| import json | |
class PasswordTesterDataset:
    """Dataset over the tensors stored in a safetensors file, plus a JSON tokenizer.

    All tensors loaded from ``safetensors_path`` are concatenated along dim 0
    into ``self.data``; indexing the dataset indexes that tensor. Implements
    ``__len__``/``__getitem__`` (the map-style dataset protocol).

    Attributes:
        tensors: the name -> tensor mapping returned by ``load_file``.
        tokenizer: the object decoded from the tokenizer JSON file.
        data: the concatenated tensor that ``__len__``/``__getitem__`` use.
    """

    def __init__(self, safetensors_path, tokenizer_path):
        """Load tensors and tokenizer from disk and build ``self.data``.

        Args:
            safetensors_path: path to a ``.safetensors`` file.
            tokenizer_path: path to a JSON file; it is read as UTF-8.
        """
        self.tensors = load_file(safetensors_path)
        # JSON is UTF-8 by spec; without an explicit encoding, open() uses the
        # platform's locale encoding and can mis-decode non-ASCII tokens.
        with open(tokenizer_path, encoding="utf-8") as f:
            self.tokenizer = json.load(f)
        # NOTE(review): self.tensors stays referenced after torch.cat copies the
        # data, so both the originals and the concatenation are held in memory.
        self.data = self._concat_tensors(self.tensors)

    @staticmethod
    def _concat_tensors(tensors):
        """Concatenate the values of ``tensors`` along dim 0, in iteration order.

        1-D inputs yield one long 1-D tensor; higher-rank inputs must agree on
        every dim except dim 0 (the ``torch.cat`` rule) and are stacked, not
        flattened. 0-d scalars are promoted to 1-element 1-D tensors because
        ``torch.cat`` rejects 0-d inputs. An empty mapping yields an empty 1-D
        tensor instead of letting ``torch.cat`` raise on an empty list.
        """
        parts = [torch.atleast_1d(t) for t in tensors.values()]
        if not parts:
            return torch.empty(0)
        return torch.cat(parts, dim=0)

    def __len__(self):
        """Return the size of dim 0 of ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``self.data[idx]``: a row for multi-dim data, a 0-d tensor for 1-D data."""
        return self.data[idx]