Create data/dataset.py
Browse files- data/dataset.py +100 -0
data/dataset.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
from torch.utils.data import IterableDataset, DataLoader
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
from transformers import PreTrainedTokenizerFast
|
| 6 |
+
from typing import Optional, Iterator
|
| 7 |
+
|
| 8 |
+
# Relative sampling weight for each language subset. The dataset's __iter__
# passes these to random.choices, so only the ratios matter; as written they
# sum to 1.0. Key order is also the default language order of the dataset.
LANGUAGES = {
    "python": 0.35,
    "javascript": 0.20,
    "typescript": 0.15,
    "cpp": 0.10,
    "rust": 0.08,
    "go": 0.07,
    "java": 0.05,
}

# Marker string prepended to every sample's text before tokenization so the
# language of the sample is visible in the token stream.
LANG_TOKEN_MAP = {
    "python": "<|python|>",
    "javascript": "<|javascript|>",
    "typescript": "<|typescript|>",
    "cpp": "<|cpp|>",
    "rust": "<|rust|>",
    "go": "<|go|>",
    "java": "<|java|>",
}
|
| 17 |
+
|
| 18 |
+
class TheStackStreamDataset(IterableDataset):
    """Stream tokenized code samples from per-language subsets of The Stack.

    Each iteration step picks a language at random (weighted by ``LANGUAGES``),
    pulls the next record from that language's stream, optionally rewrites it
    into fill-in-the-middle (FIM) form, tokenizes it and yields a dict with
    ``input_ids`` and ``labels`` tensors.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list
            when called as ``tokenizer(text, max_length=..., truncation=True)``.
        max_length: Maximum number of tokens passed to the tokenizer.
        languages: Language names to stream; defaults to all ``LANGUAGES`` keys.
        split: Dataset split passed to ``load_dataset``.
        max_samples_per_lang: A language is retired once it has yielded this
            many samples.
        fim_rate: Probability-like threshold for applying the FIM rewrite.
    """

    def __init__(self, tokenizer, max_length=2048, languages=None,
                 split="train", max_samples_per_lang=500_000, fim_rate=0.5):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.languages = languages or list(LANGUAGES.keys())
        self.split = split
        self.max_samples = max_samples_per_lang
        self.fim_rate = fim_rate

    def _get_lang_dataset(self, lang):
        """Open the streaming dataset for one language subdirectory."""
        # NOTE(review): trust_remote_code=True permits code shipped with the
        # dataset repository to run locally — confirm it is still required.
        return load_dataset(
            "bigcode/the-stack", data_dir=f"data/{lang}",
            split=self.split, streaming=True, trust_remote_code=True,
        )

    def _tokenize(self, code, lang):
        """Tokenize ``code`` prefixed with its language marker.

        Returns:
            A dict with identical ``input_ids`` and ``labels`` long tensors, or
            ``None`` when the sample produces fewer than 64 token ids.
        """
        # NOTE(review): truncation runs after the FIM rewrite, which places the
        # middle segment last; assuming right-side truncation, long samples may
        # lose part of that segment — confirm this is acceptable.
        text = f"{LANG_TOKEN_MAP.get(lang, '')}{code}"
        tokens = self.tokenizer(text, max_length=self.max_length, truncation=True)
        ids = tokens["input_ids"]
        if len(ids) < 64:
            return None
        return {"input_ids": torch.tensor(ids, dtype=torch.long),
                "labels": torch.tensor(ids, dtype=torch.long)}

    def _apply_fim(self, code):
        """Randomly rewrite ``code`` into prefix/suffix/middle (FIM) order.

        The text is returned unchanged when the random draw exceeds
        ``fim_rate`` or when it has fewer than four lines. Otherwise a span of
        1-10 whole lines is cut out as the middle segment. The three segments
        always concatenate back to the original text.
        """
        if random.random() > self.fim_rate:
            return code
        lines = code.split("\n")
        if len(lines) < 4:
            return code
        start = random.randint(1, len(lines) - 3)
        end = random.randint(start + 1, min(start + 10, len(lines) - 1))
        # The prefix and the middle are each followed by another line in the
        # original text, so their trailing newline must be kept; otherwise
        # prefix + middle + suffix would glue adjacent lines together.
        prefix = "\n".join(lines[:start]) + "\n"
        middle = "\n".join(lines[start:end]) + "\n"
        suffix = "\n".join(lines[end:])
        return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>{middle}"

    @staticmethod
    def _weights_for(langs):
        """Return the sampling weight of each language, defaulting to 1.0."""
        return [LANGUAGES.get(name, 1.0) for name in langs]

    def __iter__(self):
        """Yield tokenized samples, interleaving languages by weight.

        Languages that fail to load are skipped with a printed warning. A
        language is dropped from rotation when its stream is exhausted or it
        has yielded ``max_samples_per_lang`` samples; iteration ends once no
        language remains.
        """
        datasets = {}
        for lang in self.languages:
            try:
                datasets[lang] = iter(self._get_lang_dataset(lang))
            except Exception as e:
                # Best effort: one unavailable language should not stop the rest.
                print(f"Warning: could not load {lang}: {e}")

        lang_list = list(datasets.keys())
        weights = self._weights_for(lang_list)
        counts = {name: 0 for name in lang_list}

        while lang_list:
            lang = random.choices(lang_list, weights=weights, k=1)[0]
            try:
                sample = next(datasets[lang])
            except StopIteration:
                # Stream exhausted: retire the language and rebalance.
                lang_list.remove(lang)
                weights = self._weights_for(lang_list)
                continue

            # "content" may be missing or None; treat both as empty text.
            code = sample.get("content") or ""
            if not code.strip():
                continue

            item = self._tokenize(self._apply_fim(code), lang)
            if item is None:
                continue

            counts[lang] += 1
            yield item
            if counts[lang] >= self.max_samples:
                lang_list.remove(lang)
                weights = self._weights_for(lang_list)
|
| 84 |
+
|
| 85 |
+
class CodeCollator:
    """Pad a list of tokenized samples into rectangular batch tensors.

    Args:
        pad_token_id: Value used to fill ``input_ids`` past each sample's end.
            ``None`` falls back to 0.
        max_length: Upper bound on the padded sequence length; longer samples
            are cut to this length.
    """

    def __init__(self, pad_token_id=0, max_length=2048):
        self.pad_id = 0 if pad_token_id is None else pad_token_id
        self.max_length = max_length

    def __call__(self, batch):
        """Collate ``batch`` into ``input_ids``, ``labels`` and ``attention_mask``.

        Args:
            batch: Non-empty sequence of dicts holding 1-D ``input_ids`` and
                ``labels`` tensors.

        Returns:
            Dict of three long tensors of shape ``(len(batch), max_len)``, where
            ``max_len`` is the longest sample capped at ``self.max_length``.
            Padding positions hold the pad id in ``input_ids``, -100 in
            ``labels`` and 0 in ``attention_mask``.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        max_len = min(max(len(x["input_ids"]) for x in batch), self.max_length)
        shape = (len(batch), max_len)
        # Pad inputs with the configured pad id rather than a hard-coded zero.
        input_ids = torch.full(shape, self.pad_id, dtype=torch.long)
        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss.
        labels = torch.full(shape, -100, dtype=torch.long)
        attention_mask = torch.zeros(shape, dtype=torch.long)
        for row, item in enumerate(batch):
            length = min(len(item["input_ids"]), max_len)
            input_ids[row, :length] = item["input_ids"][:length]
            labels[row, :length] = item["labels"][:length]
            attention_mask[row, :length] = 1
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
|