algorythmtechnologies committed on
Commit
a55cadf
·
verified ·
1 Parent(s): 4e8d206

Update supernova/data.py

Browse files
Files changed (1) hide show
  1. supernova/data.py +31 -18
supernova/data.py CHANGED
@@ -8,7 +8,6 @@ from datasets import load_dataset
8
  from transformers import PreTrainedTokenizerBase
9
  import yaml
10
 
11
-
12
  @dataclass
13
  class DataSource:
14
  name: str
@@ -19,7 +18,6 @@ class DataSource:
19
  weight: int = 1
20
  streaming: bool = True
21
 
22
-
23
  def load_sources_from_yaml(path: str) -> List[DataSource]:
24
  with open(path, "r", encoding="utf-8") as f:
25
  cfg = yaml.safe_load(f)
@@ -37,7 +35,6 @@ def load_sources_from_yaml(path: str) -> List[DataSource]:
37
  assert len(srcs) > 0, "No data sources configured"
38
  return srcs
39
 
40
-
41
  def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
42
  iters = []
43
  for s in sources:
@@ -45,7 +42,6 @@ def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
45
  iters.append(iter(ds))
46
  return iters
47
 
48
-
49
  def weighted_choice(weights: List[int]) -> int:
50
  total = sum(weights)
51
  r = random.randint(1, total)
@@ -56,7 +52,6 @@ def weighted_choice(weights: List[int]) -> int:
56
  return i
57
  return len(weights) - 1
58
 
59
-
60
  class TokenChunkDataset(IterableDataset):
61
  def __init__(
62
  self,
@@ -76,22 +71,35 @@ class TokenChunkDataset(IterableDataset):
76
  iters = build_streams(self.sources)
77
  while True:
78
  i = weighted_choice(self.weights)
79
- def __len__(self):
80
- return 1000000 # enables progress bar if you use one
81
-
82
- def _safe_encode(self, text: str) -> list:
83
- try:
84
- return self.tok.encode(text)
85
- except Exception as e:
86
- print(f"Encoding error for text: {text[:50]}... Error: {e}")
87
- return []
 
 
 
 
 
 
88
  text = row.get(self.sources[i].text_field, None)
89
  if isinstance(text, str) and len(text) > 0:
90
  yield text
91
 
 
 
 
 
 
 
 
92
  def _iter_token_ids(self) -> Iterator[int]:
93
  for text in self._iter_texts():
94
- ids = self.tok.encode(text)
95
  if self.eos_id is not None:
96
  ids.append(self.eos_id)
97
  for t in ids:
@@ -102,7 +110,12 @@ def _safe_encode(self, text: str) -> list:
102
  for tok_id in self._iter_token_ids():
103
  buf.append(tok_id)
104
  while len(buf) >= self.seq_len + 1:
105
- x = torch.tensor(buf[: self.seq_len], dtype=torch.long)
106
- y = torch.tensor(buf[1 : self.seq_len + 1], dtype=torch.long)
107
- del buf[: self.seq_len]
108
  yield x, y
 
 
 
 
 
 
8
  from transformers import PreTrainedTokenizerBase
9
  import yaml
10
 
 
11
  @dataclass
12
  class DataSource:
13
  name: str
 
18
  weight: int = 1
19
  streaming: bool = True
20
 
 
21
  def load_sources_from_yaml(path: str) -> List[DataSource]:
22
  with open(path, "r", encoding="utf-8") as f:
23
  cfg = yaml.safe_load(f)
 
35
  assert len(srcs) > 0, "No data sources configured"
36
  return srcs
37
 
 
38
  def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
39
  iters = []
40
  for s in sources:
 
42
  iters.append(iter(ds))
43
  return iters
44
 
 
45
  def weighted_choice(weights: List[int]) -> int:
46
  total = sum(weights)
47
  r = random.randint(1, total)
 
52
  return i
53
  return len(weights) - 1
54
 
 
55
  class TokenChunkDataset(IterableDataset):
56
  def __init__(
57
  self,
 
71
  iters = build_streams(self.sources)
72
  while True:
73
  i = weighted_choice(self.weights)
74
+ try:
75
+ row = next(iters[i])
76
+ except StopIteration:
77
+ try:
78
+ ds = load_dataset(
79
+ self.sources[i].hf_path,
80
+ self.sources[i].hf_name,
81
+ split=self.sources[i].split,
82
+ streaming=self.sources[i].streaming
83
+ )
84
+ iters[i] = iter(ds)
85
+ row = next(iters[i])
86
+ except (StopIteration, Exception) as e:
87
+ print(f"Warning: Could not restart iterator for source {self.sources[i].name}: {e}")
88
+ continue # Skip this iteration and try next source
89
  text = row.get(self.sources[i].text_field, None)
90
  if isinstance(text, str) and len(text) > 0:
91
  yield text
92
 
93
+ def _safe_encode(self, text: str) -> list:
94
+ try:
95
+ return self.tok.encode(text)
96
+ except Exception as e:
97
+ print(f"Encoding error for text: {text[:50]}... Error: {e}")
98
+ return []
99
+
100
  def _iter_token_ids(self) -> Iterator[int]:
101
  for text in self._iter_texts():
102
+ ids = self._safe_encode(text)
103
  if self.eos_id is not None:
104
  ids.append(self.eos_id)
105
  for t in ids:
 
110
  for tok_id in self._iter_token_ids():
111
  buf.append(tok_id)
112
  while len(buf) >= self.seq_len + 1:
113
+ x = torch.tensor(buf[:self.seq_len], dtype=torch.long)
114
+ y = torch.tensor(buf[1:self.seq_len + 1], dtype=torch.long)
115
+ del buf[:self.seq_len]
116
  yield x, y
117
+
118
+ def __len__(self):
119
+ # Provide approximate length for progress tracking
120
+ return 1000000 # Large number for streaming datasets
121
+