File size: 524 Bytes
8749343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

import torch
from safetensors.torch import load_file
import json

class PasswordTesterDataset:
    """Index-addressable dataset built from the tensors in a safetensors file.

    All tensors loaded from the file are concatenated along dimension 0 into
    one tensor, ``self.data``. ``len()`` and ``[]`` operate on that tensor, so
    one item is one slice along its first dimension.
    """

    def __init__(self, safetensors_path, tokenizer_path):
        """Load the tensors and the tokenizer JSON.

        Args:
            safetensors_path: Path handed to ``safetensors.torch.load_file``.
            tokenizer_path: Path to a JSON file. Its parsed content is stored
                as-is in ``self.tokenizer``; this class does not interpret it.

        Raises:
            OSError: If the tokenizer file cannot be opened.
            json.JSONDecodeError: If the tokenizer file is not valid JSON.
        """
        # Kept on the instance so callers can still reach the individual
        # tensors (presumably a name -> tensor mapping; only ``.values()`` is
        # used here).
        self.tensors = load_file(safetensors_path)

        # Decode explicitly as UTF-8: without it ``open`` falls back to the
        # platform locale encoding, which corrupts non-ASCII tokens on
        # systems whose default is not UTF-8.
        with open(tokenizer_path, encoding="utf-8") as f:
            self.tokenizer = json.load(f)

        # Concatenate along dim 0, in the mapping's iteration order. This
        # joins the tensors end to end but does not flatten them; torch.cat
        # needs a non-empty sequence of tensors with compatible shapes.
        self.data = torch.cat(list(self.tensors.values()), dim=0)

    def __len__(self) -> int:
        """Return the number of items, i.e. the size of ``self.data`` along dim 0."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``self.data[idx]``; ``idx`` is forwarded to tensor indexing unchanged."""
        return self.data[idx]