Spaces:
Build error
Build error
Create src/data_loader.py
Browse files- src/data_loader.py +14 -0
src/data_loader.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
|
| 4 |
+
def get_dataloader(config, tokenizer, split='train'):
    """Build a PyTorch DataLoader over the CodeSearchNet (python) dataset.

    Args:
        config: Nested config dict; reads config['model']['max_length'] and
            config['training']['batch_size'].
        tokenizer: A Hugging Face-style tokenizer callable accepting a list of
            strings plus truncation/padding/max_length keyword arguments.
        split: Dataset split to load (e.g. 'train', 'validation', 'test').
            Only the 'train' split is shuffled.

    Returns:
        A torch.utils.data.DataLoader yielding batches of tokenized tensors
        (input_ids, attention_mask, ... — whatever the tokenizer produces).
    """
    dataset = load_dataset("code_search_net", "python", split=split)

    # Hoist the config lookup out of the per-batch closure; it is invariant
    # across every map() invocation.
    max_length = config['model']['max_length']

    def tokenize_function(examples):
        # Batched mode: examples['whole_func_string'] is a list of strings.
        return tokenizer(
            examples['whole_func_string'],
            truncation=True,
            padding='max_length',
            max_length=max_length,
        )

    # Drop every original (non-tensor) column via dataset.column_names rather
    # than a hard-coded 11-name list: remove_columns raises on unknown names,
    # so the hard-coded list breaks if the upstream schema changes.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )
    tokenized_dataset.set_format("torch")

    return DataLoader(
        tokenized_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=(split == 'train'),  # shuffle only the training split
    )
|