# NOTE(review): stripped non-Python viewer artifacts (file size, commit hash,
# line-number gutter) that would make this file unimportable.
import torchvision.datasets as dset
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
import glob
import os
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader, random_split
class GithubDataset(Dataset):
    """Dataset of source files tokenized with a Hugging Face tokenizer.

    Each item is one file from ``root_dir``, read as UTF-8 text (undecodable
    bytes ignored) and tokenized to a fixed length.

    Args:
        root_dir: Directory containing the corpus files (matched via ``*.*``).
        train: Unused; kept only for backward compatibility with callers.
        max_length: Fixed token length — shorter files are padded, longer
            files truncated.
        tokenizer_name: Hugging Face model id whose tokenizer to load.
    """

    def __init__(
        self,
        root_dir=os.path.expanduser("~/torch_datasets/github-python/corpus"),
        train=False,
        max_length=512,
        tokenizer_name="bert-base-uncased",
    ):
        self.root = root_dir
        # Sort so indices are deterministic across runs and filesystems:
        # glob() order is otherwise arbitrary, which would make any
        # index-based split (e.g. random_split with a fixed seed) unstable.
        self.file_list = sorted(glob.glob(os.path.join(root_dir, "*.*")))
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length

    def __len__(self):
        """Return the number of corpus files found."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` 1-D tensors for file ``idx``."""
        path = self.file_list[idx]
        with open(path, "r", encoding="utf-8", errors="ignore") as file:
            code = file.read()
        encoding = self.tokenizer(
            code,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze(0): the tokenizer returns shape (1, max_length); the
        # DataLoader re-adds the batch dimension later, so drop it here.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return input_ids, attention_mask
# Build the corpus dataset and a size-deterministic 80/20 train/test split.
# NOTE(review): the original first constructed GithubDataset() with the
# default root_dir and immediately discarded it on the next line — dead work
# (including a tokenizer load) removed here.
dataset = GithubDataset(root_dir="./test-data/")
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# random_split uses torch's global RNG; seed torch for reproducible splits.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
def get_train_dataset():
    """Return the module-level training split (80% of the corpus)."""
    split = train_dataset
    return split
def get_test_dataset():
    """Return the module-level held-out split (20% of the corpus)."""
    split = test_dataset
    return split
def get_dataloader(dataset, batch_size=64):
    """Wrap *dataset* in a shuffling DataLoader with the given batch size."""
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader
if __name__ == "__main__":
    # Smoke test: report the train-split size and decode one sample
    # token-by-token (each wordpiece printed space-separated).
    train_split = get_train_dataset()
    print("Number of samples: ", len(train_split))
    input_ids, attention_mask = train_split[4]
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    for token_id in input_ids:
        print(tok.decode(token_id.item()), end=" ")
    print()