# NOTE(review): the following lines are Hugging Face file-viewer page residue
# captured by the scrape, not source code; commented out so the file parses.
# jblitzar's picture
# Upload folder using huggingface_hub
# a8639ac verified
# raw
# history blame
# 2 kB
import torchvision.datasets as dset
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
import glob
import os
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader, random_split
class GithubDataset(Dataset):
    """Dataset of source files read from a directory and tokenized to a fixed length.

    Each item corresponds to one file directly inside ``root_dir`` (matched by
    the non-recursive glob ``*.*``) and is returned as BERT-style token tensors.

    Args:
        root_dir: Directory containing the corpus files.
        train: Unused; kept for interface compatibility with existing callers.
        max_length: Sequence length every sample is padded/truncated to.
        tokenizer_name: Hugging Face model name passed to ``AutoTokenizer``.
            Defaults to ``"bert-base-uncased"`` (previously hard-coded).
    """

    def __init__(
        self,
        root_dir=os.path.expanduser("~/torch_datasets/github-python/corpus"),
        train=False,
        max_length=512,
        tokenizer_name="bert-base-uncased",
    ):
        self.root = root_dir
        # Non-recursive: only files with an extension directly under root_dir.
        self.file_list = glob.glob(os.path.join(root_dir, "*.*"))
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length

    def __len__(self):
        """Return the number of corpus files found at construction time."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Read and tokenize the file at index ``idx``.

        Returns:
            Tuple ``(input_ids, attention_mask)``, each a 1-D tensor of
            length ``max_length`` (the tokenizer's batch dim is squeezed off).
        """
        path = self.file_list[idx]
        # errors="ignore" drops undecodable bytes so non-UTF-8 files don't raise.
        with open(path, "r", encoding="utf-8", errors="ignore") as file:
            code = file.read()
        encoding = self.tokenizer(
            code,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" yields shape (1, max_length); squeeze the batch dim.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return input_ids, attention_mask
# Build the dataset once. The previous code constructed GithubDataset() twice
# (first with the default home-directory corpus, downloading a tokenizer, then
# immediately discarding it); only the ./test-data/ instance was ever used.
dataset = GithubDataset(root_dir="./test-data/")

# 80/20 train/test split of the corpus files.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
def get_train_dataset():
    """Accessor for the module-level training split produced by ``random_split``."""
    return train_dataset
def get_test_dataset():
    """Accessor for the module-level held-out split produced by ``random_split``."""
    return test_dataset
def get_dataloader(dataset, batch_size=64):
    """Wrap ``dataset`` in a shuffling ``DataLoader``.

    Args:
        dataset: Any torch-compatible dataset (map- or index-style).
        batch_size: Number of samples per batch (default 64).

    Returns:
        A ``DataLoader`` that shuffles every epoch.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader
if __name__ == "__main__":
    # Smoke test: decode one training sample token-by-token to eyeball output.
    sample_ds = get_train_dataset()
    print("Number of samples: ", len(sample_ds))
    input_ids, _attention_mask = sample_ds[4]
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    for token_id in input_ids:
        print(tok.decode(token_id.item()), end=" ")
    print()