devoppro committed on
Commit
7ffa6dd
·
verified ·
1 Parent(s): 8c29328

Create data/dataset.py

Browse files
Files changed (1) hide show
  1. data/dataset.py +100 -0
data/dataset.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ from torch.utils.data import IterableDataset, DataLoader
4
+ from datasets import load_dataset
5
+ from transformers import PreTrainedTokenizerFast
6
+ from typing import Optional, Iterator
7
+
8
# Relative sampling weight per language, used when mixing the per-language
# streams. The weights below sum to 1.0, but they are only ever passed to
# random.choices as relative weights, so they do not have to.
LANGUAGES = {
    "python": 0.35,
    "javascript": 0.20,
    "typescript": 0.15,
    "cpp": 0.10,
    "rust": 0.08,
    "go": 0.07,
    "java": 0.05,
}

# Language tag prepended to every sample before tokenization. Derived from
# LANGUAGES so the two tables cannot drift apart: every weighted language
# gets a "<|name|>" tag, in the same order as the weight table.
LANG_TOKEN_MAP = {lang: f"<|{lang}|>" for lang in LANGUAGES}
17
+
18
class TheStackStreamDataset(IterableDataset):
    """Stream tokenized code samples from ``bigcode/the-stack``.

    One streaming dataset is opened per language. Each step picks a language
    at random (weighted by ``LANGUAGES``), pulls the next file from that
    language's stream, optionally rewrites it into fill-in-the-middle (FIM)
    form, prepends the language tag and tokenizes it.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, max_length=..., truncation=True)``; the result
            must expose ``["input_ids"]``.
        max_length: Maximum number of tokens per sample (passed to the
            tokenizer).
        languages: Languages to stream; defaults to every key of
            ``LANGUAGES``. Languages without an entry in ``LANGUAGES`` get a
            sampling weight of 1.0.
        split: Dataset split forwarded to ``load_dataset``.
        max_samples_per_lang: A language is dropped from the mix once this
            many samples have been yielded for it.
        fim_rate: Probability in ``[0, 1]`` of rewriting a sample into FIM
            form.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerFast", max_length: int = 2048,
                 languages: Optional[list] = None, split: str = "train",
                 max_samples_per_lang: int = 500_000, fim_rate: float = 0.5):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.languages = languages or list(LANGUAGES)
        self.split = split
        self.max_samples = max_samples_per_lang
        self.fim_rate = fim_rate

    def _get_lang_dataset(self, lang: str):
        """Open the streaming dataset for a single language."""
        # NOTE(review): trust_remote_code=True lets the dataset repository's
        # loading script run on this machine. Confirm it is actually required
        # for this dataset and drop it otherwise.
        return load_dataset(
            "bigcode/the-stack", data_dir=f"data/{lang}",
            split=self.split, streaming=True, trust_remote_code=True,
        )

    def _tokenize(self, code: str, lang: str) -> Optional[dict]:
        """Tokenize ``code`` with its language tag prepended.

        Returns:
            A dict with ``input_ids`` and ``labels`` (separate tensors holding
            the same ids), or ``None`` when the sample has fewer than 64
            tokens.
        """
        text = f"{LANG_TOKEN_MAP.get(lang, '')}{code}"
        # NOTE(review): if the tokenizer truncates on the right, an over-long
        # FIM sample loses (part of) its middle segment -- confirm this is
        # acceptable.
        ids = self.tokenizer(text, max_length=self.max_length, truncation=True)["input_ids"]
        if len(ids) < 64:
            return None
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "labels": torch.tensor(ids, dtype=torch.long),
        }

    def _apply_fim(self, code: str) -> str:
        """Rewrite ``code`` into prefix/suffix/middle order with probability ``fim_rate``.

        The split is made on line boundaries; the middle spans 1-10 lines and
        the prefix and suffix each keep at least one line. Code with fewer
        than four lines is returned unchanged.
        """
        # random.random() is in [0.0, 1.0), so ">=" makes fim_rate=0.0 mean
        # "never" and fim_rate=1.0 mean "always".
        if random.random() >= self.fim_rate:
            return code
        lines = code.split("\n")
        if len(lines) < 4:
            return code
        start = random.randint(1, len(lines) - 3)
        end = random.randint(start + 1, min(start + 10, len(lines) - 1))
        prefix = "\n".join(lines[:start])
        middle = "\n".join(lines[start:end])
        suffix = "\n".join(lines[end:])
        return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>{middle}"

    @staticmethod
    def _weights_for(langs: list) -> list:
        """Return the sampling weight of each language in ``langs``."""
        return [LANGUAGES.get(lang, 1.0) for lang in langs]

    def __iter__(self) -> Iterator[dict]:
        """Yield tokenized samples until every language is exhausted or capped."""
        streams = {}
        for lang in self.languages:
            try:
                streams[lang] = iter(self._get_lang_dataset(lang))
            except Exception as e:
                # Best effort: a language that cannot be opened is skipped so
                # the remaining languages can still be streamed.
                print(f"Warning: could not load {lang}: {e}")

        active = list(streams)
        weights = self._weights_for(active)
        counts = dict.fromkeys(active, 0)

        while active:
            lang = random.choices(active, weights=weights, k=1)[0]
            # Only next() is guarded, so a StopIteration raised elsewhere is
            # never mistaken for stream exhaustion.
            try:
                sample = next(streams[lang])
            except StopIteration:
                active.remove(lang)
                weights = self._weights_for(active)
                continue

            # "content" may be missing or explicitly None; treat both as empty.
            code = sample.get("content") or ""
            if not code.strip():
                continue

            item = self._tokenize(self._apply_fim(code), lang)
            if item is None:
                continue

            counts[lang] += 1
            yield item
            if counts[lang] >= self.max_samples:
                active.remove(lang)
                weights = self._weights_for(active)
84
+
85
class CodeCollator:
    """Pad a list of tokenized samples into rectangular batch tensors.

    Args:
        pad_token_id: Value written into the padded positions of
            ``input_ids``.
        max_length: Upper bound on the batch width; longer samples are
            truncated on the right.
    """

    def __init__(self, pad_token_id=0, max_length=2048):
        self.pad_id = pad_token_id
        self.max_length = max_length

    def __call__(self, batch):
        """Collate ``batch`` into ``input_ids``, ``labels`` and ``attention_mask``.

        Args:
            batch: Non-empty sequence of dicts, each with ``"input_ids"`` and
                ``"labels"`` sequences.

        Returns:
            Dict of three ``torch.long`` tensors of shape
            ``(len(batch), width)``, where ``width`` is the longest sample
            capped at ``max_length``. Padded positions hold ``pad_token_id``
            in ``input_ids``, ``-100`` in ``labels`` and ``0`` in
            ``attention_mask``.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        if not batch:
            raise ValueError("CodeCollator received an empty batch")

        rows = len(batch)
        width = min(max(len(x["input_ids"]) for x in batch), self.max_length)

        # Pad with the configured pad id rather than a hard-coded zero.
        input_ids = torch.full((rows, width), self.pad_id, dtype=torch.long)
        # -100 marks the padded label positions.
        labels = torch.full((rows, width), -100, dtype=torch.long)
        attention_mask = torch.zeros(rows, width, dtype=torch.long)

        for i, item in enumerate(batch):
            length = min(len(item["input_ids"]), width)
            input_ids[i, :length] = torch.as_tensor(item["input_ids"][:length], dtype=torch.long)
            labels[i, :length] = torch.as_tensor(item["labels"][:length], dtype=torch.long)
            attention_mask[i, :length] = 1

        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}