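"""Dataset utilities for a corpus of Python source files scraped from GitHub.

Each file under the corpus directory is read as one example, tokenized with the
``bert-base-uncased`` tokenizer, and returned as an (input_ids, attention_mask)
tensor pair.
"""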
import glob
import os

from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer


class GithubDataset(Dataset):
    """Treats every file under ``root_dir`` as a single tokenized example."""

    def __init__(
        self,
        root_dir=os.path.expanduser("~/torch_datasets/github-python/corpus"),
        train=False,  # currently unused
        max_length=512,
    ):
        self.root = root_dir
        self.file_list = glob.glob(os.path.join(root_dir, "*.*"))
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.max_length = max_length

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]

        # Read the source file, ignoring any bytes that are not valid UTF-8.
        with open(path, "r", encoding="utf-8", errors="ignore") as file:
            code = file.read()

        # Tokenize, padding or truncating to a fixed length of `max_length` tokens.
        encoding = self.tokenizer(
            code,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Drop the batch dimension added by return_tensors="pt".
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        return input_ids, attention_mask


# Build the dataset from the local test directory and split it 80/20 below.
dataset = GithubDataset(root_dir="./test-data/")
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])


def get_train_dataset():
    return train_dataset


def get_test_dataset():
    return test_dataset


def get_dataloader(dataset, batch_size=64):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


if __name__ == "__main__":
    d = get_train_dataset()
    print("Number of samples:", len(d))

    # Decode one sample back to text to sanity-check the tokenization.
    a, b = d[4]
    t = AutoTokenizer.from_pretrained("bert-base-uncased")
    for i in a:
        print(t.decode(i.item()), end=" ")
    print()
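    # Illustrative usage: pull one batch through the DataLoader helper above
    # (batch_size=4 is arbitrary; any value up to the train split size works).
    loader = get_dataloader(d, batch_size=4)
    input_ids, attention_mask = next(iter(loader))
    print("Batch shapes:", input_ids.shape, attention_mask.shape)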