algorythmtechnologies committed on
Commit a38941f · verified · 1 Parent(s): dbf4e29

Update supernova/data.py

Files changed (1)
  1. supernova/data.py +108 -105
supernova/data.py CHANGED
@@ -1,105 +1,108 @@
- import random
- from dataclasses import dataclass
- from typing import Dict, Iterable, Iterator, List, Optional, Tuple
-
- import torch
- from torch.utils.data import IterableDataset
- from datasets import load_dataset
- from transformers import PreTrainedTokenizerBase
- import yaml
-
-
- @dataclass
- class DataSource:
-     name: str
-     hf_path: str
-     hf_name: Optional[str]
-     split: str
-     text_field: str
-     weight: int = 1
-     streaming: bool = True
-
-
- def load_sources_from_yaml(path: str) -> List[DataSource]:
-     with open(path, "r", encoding="utf-8") as f:
-         cfg = yaml.safe_load(f)
-     srcs = []
-     for s in cfg.get("sources", []):
-         srcs.append(DataSource(
-             name=s.get("name"),
-             hf_path=s.get("hf_path"),
-             hf_name=s.get("hf_name"),
-             split=s.get("split", "train"),
-             text_field=s.get("text_field", "text"),
-             weight=int(s.get("weight", 1)),
-             streaming=bool(s.get("streaming", True)),
-         ))
-     assert len(srcs) > 0, "No data sources configured"
-     return srcs
-
-
- def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
-     iters = []
-     for s in sources:
-         ds = load_dataset(s.hf_path, s.hf_name, split=s.split, streaming=s.streaming)
-         iters.append(iter(ds))
-     return iters
-
-
- def weighted_choice(weights: List[int]) -> int:
-     total = sum(weights)
-     r = random.randint(1, total)
-     acc = 0
-     for i, w in enumerate(weights):
-         acc += w
-         if r <= acc:
-             return i
-     return len(weights) - 1
-
-
- class TokenChunkDataset(IterableDataset):
-     def __init__(
-         self,
-         tokenizer: PreTrainedTokenizerBase,
-         sources: List[DataSource],
-         seq_len: int,
-         eos_token_id: Optional[int] = None,
-     ):
-         super().__init__()
-         self.tok = tokenizer
-         self.sources = sources
-         self.seq_len = seq_len
-         self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
-         self.weights = [max(1, s.weight) for s in sources]
-
-     def _iter_texts(self) -> Iterator[str]:
-         iters = build_streams(self.sources)
-         while True:
-             i = weighted_choice(self.weights)
-             try:
-                 row = next(iters[i])
-             except StopIteration:
-                 # restart that iterator if streaming was False
-                 iters[i] = build_streams([self.sources[i]])[0]
-                 row = next(iters[i])
-             text = row.get(self.sources[i].text_field, None)
-             if isinstance(text, str) and len(text) > 0:
-                 yield text
-
-     def _iter_token_ids(self) -> Iterator[int]:
-         for text in self._iter_texts():
-             ids = self.tok.encode(text)
-             if self.eos_id is not None:
-                 ids.append(self.eos_id)
-             for t in ids:
-                 yield t
-
-     def __iter__(self):
-         buf: List[int] = []
-         for tok_id in self._iter_token_ids():
-             buf.append(tok_id)
-             while len(buf) >= self.seq_len + 1:
-                 x = torch.tensor(buf[: self.seq_len], dtype=torch.long)
-                 y = torch.tensor(buf[1 : self.seq_len + 1], dtype=torch.long)
-                 del buf[: self.seq_len]
-                 yield x, y
 
 
 
 
+ import random
+ from dataclasses import dataclass
+ from typing import Dict, Iterable, Iterator, List, Optional, Tuple
+
+ import torch
+ from torch.utils.data import IterableDataset
+ from datasets import load_dataset
+ from transformers import PreTrainedTokenizerBase
+ import yaml
+
+
+ @dataclass
+ class DataSource:
+     name: str
+     hf_path: str
+     hf_name: Optional[str]
+     split: str
+     text_field: str
+     weight: int = 1
+     streaming: bool = True
+
+
+ def load_sources_from_yaml(path: str) -> List[DataSource]:
+     with open(path, "r", encoding="utf-8") as f:
+         cfg = yaml.safe_load(f)
+     srcs = []
+     for s in cfg.get("sources", []):
+         srcs.append(DataSource(
+             name=s.get("name"),
+             hf_path=s.get("hf_path"),
+             hf_name=s.get("hf_name"),
+             split=s.get("split", "train"),
+             text_field=s.get("text_field", "text"),
+             weight=int(s.get("weight", 1)),
+             streaming=bool(s.get("streaming", True)),
+         ))
+     assert len(srcs) > 0, "No data sources configured"
+     return srcs
+
+
+ def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
+     iters = []
+     for s in sources:
+         ds = load_dataset(s.hf_path, s.hf_name, split=s.split, streaming=s.streaming)
+         iters.append(iter(ds))
+     return iters
+
+
+ def weighted_choice(weights: List[int]) -> int:
+     total = sum(weights)
+     r = random.randint(1, total)
+     acc = 0
+     for i, w in enumerate(weights):
+         acc += w
+         if r <= acc:
+             return i
+     return len(weights) - 1
+
+
+ class TokenChunkDataset(IterableDataset):
+     def __init__(
+         self,
+         tokenizer: PreTrainedTokenizerBase,
+         sources: List[DataSource],
+         seq_len: int,
+         eos_token_id: Optional[int] = None,
+     ):
+         super().__init__()
+         self.tok = tokenizer
+         self.sources = sources
+         self.seq_len = seq_len
+         self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
+         self.weights = [max(1, s.weight) for s in sources]
+
+     def __len__(self):
+         # Nominal length: the stream is unbounded, but returning a fixed
+         # count enables progress bars if you use one.
+         return 1000000
+
+     def _safe_encode(self, text: str) -> List[int]:
+         # Tolerate tokenizer failures on malformed rows instead of crashing
+         # the whole training stream.
+         try:
+             return self.tok.encode(text)
+         except Exception as e:
+             print(f"Encoding error for text: {text[:50]}... Error: {e}")
+             return []
+
+     def _iter_texts(self) -> Iterator[str]:
+         iters = build_streams(self.sources)
+         while True:
+             i = weighted_choice(self.weights)
+             try:
+                 row = next(iters[i])
+             except StopIteration:
+                 # restart that iterator if streaming was False
+                 iters[i] = build_streams([self.sources[i]])[0]
+                 row = next(iters[i])
+             text = row.get(self.sources[i].text_field, None)
+             if isinstance(text, str) and len(text) > 0:
+                 yield text
+
+     def _iter_token_ids(self) -> Iterator[int]:
+         for text in self._iter_texts():
+             ids = self._safe_encode(text)
+             if not ids:
+                 # skip texts that failed to encode
+                 continue
+             if self.eos_id is not None:
+                 ids.append(self.eos_id)
+             for t in ids:
+                 yield t
+
+     def __iter__(self):
+         buf: List[int] = []
+         for tok_id in self._iter_token_ids():
+             buf.append(tok_id)
+             while len(buf) >= self.seq_len + 1:
+                 x = torch.tensor(buf[: self.seq_len], dtype=torch.long)
+                 y = torch.tensor(buf[1 : self.seq_len + 1], dtype=torch.long)
+                 del buf[: self.seq_len]
+                 yield x, y
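
For context, a minimal usage sketch (not part of the commit) showing how the updated file would be exercised end to end: an illustrative YAML config is written to a temp file, parsed with load_sources_from_yaml, and the resulting TokenChunkDataset is wrapped in a DataLoader. The wikitext dataset and gpt2 tokenizer are assumptions chosen for illustration, not values taken from this repository.

import tempfile
import textwrap

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from supernova.data import TokenChunkDataset, load_sources_from_yaml

# Illustrative config; any HF dataset with a text column works the same way.
cfg = textwrap.dedent("""\
    sources:
      - name: wikitext
        hf_path: wikitext
        hf_name: wikitext-103-raw-v1
        split: train
        text_field: text
        weight: 1
        streaming: true
""")
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(cfg)

sources = load_sources_from_yaml(f.name)
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer

dataset = TokenChunkDataset(tokenizer, sources, seq_len=512)
loader = DataLoader(dataset, batch_size=8)  # IterableDataset: no sampler/shuffle

for x, y in loader:
    # x, y: LongTensors of shape (batch, seq_len); y is x shifted by one token
    print(x.shape, y.shape)
    break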