|
|
""" |
|
|
Dataset classes for SLM training. |
|
|
|
|
|
Handles loading, preprocessing, and tokenization of conversational data. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import random |
|
|
from typing import List, Dict, Optional, Iterator, Tuple |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset, IterableDataset |
|
|
|
|
|
from .tokenizer import SLMTokenizer |
|
|
|
|
|
|
|
|
class ConversationalDataset(Dataset):
    """Map-style dataset for conversational/instruction-following data.

    Samples are loaded eagerly from JSON/JSONL files under ``data_path`` and
    kept in memory. Each sample is turned into fixed-length tensors on access:
    samples carrying ``input_ids`` are used as-is, while ``user``/``assistant``,
    ``text`` and ``question``/``answer`` samples are tokenized on the fly.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: SLMTokenizer,
        max_length: int = 1024,
        split: str = "train",
    ):
        """Initialize the dataset.

        Args:
            data_path: Path to the processed data directory
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
            split: "train" or "val"
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        self.samples = self._load_data(data_path)
        print(f"Loaded {len(self.samples)} samples for {split} split")

    @staticmethod
    def _parse_jsonl(lines) -> List[Dict]:
        """Decode an iterable of JSONL lines, ignoring blank lines."""
        stripped = (raw.strip() for raw in lines)
        return [json.loads(text) for text in stripped if text]

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load data from JSON or JSONL files.

        Sources are tried in this order and the first one found wins:

        1. ``<split>.jsonl`` (one JSON value per line).
        2. ``<split>.json`` (a JSON array, or JSONL content despite the
           extension).
        3. ``data.json`` (a dict keyed by split name, or any other JSON value
           which is returned unchanged).
        4. Every ``*.jsonl`` then every ``*.json`` file in the directory,
           concatenated in sorted order.

        Args:
            data_path: Directory containing the data files.

        Returns:
            The loaded samples; empty if no file was found.
        """
        split_jsonl = os.path.join(data_path, f"{self.split}.jsonl")
        if os.path.exists(split_jsonl):
            with open(split_jsonl, "r", encoding="utf-8") as f:
                return self._parse_jsonl(f)

        split_file = os.path.join(data_path, f"{self.split}.json")
        if os.path.exists(split_file):
            with open(split_file, "r", encoding="utf-8") as f:
                content = f.read()
            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                # Not a single JSON document: treat it as JSONL. Text mode has
                # already normalised newlines to "\n", so this matches
                # iterating the file line by line.
                return self._parse_jsonl(content.split("\n"))
            if isinstance(parsed, list):
                return parsed
            # A lone JSON value (e.g. a one-line JSONL file) is one sample.
            return [parsed]

        combined_file = os.path.join(data_path, "data.json")
        if os.path.exists(combined_file):
            with open(combined_file, "r", encoding="utf-8") as f:
                all_data = json.load(f)
            if isinstance(all_data, dict) and self.split in all_data:
                return all_data[self.split]
            # NOTE(review): a dict without the requested split key is returned
            # unchanged, which is not a list of samples - confirm intent.
            return all_data

        samples = []
        for ext in ["*.jsonl", "*.json"]:
            for file in sorted(Path(data_path).glob(ext)):
                with open(file, "r", encoding="utf-8") as f:
                    if file.suffix == ".jsonl":
                        samples.extend(self._parse_jsonl(f))
                    else:
                        data = json.load(f)
                        if isinstance(data, list):
                            samples.extend(data)
                        else:
                            samples.append(data)

        return samples

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample.

        Returns:
            Dictionary with:
                - input_ids: Token IDs for the full sequence
                - attention_mask: 1 for real tokens, 0 for padding
                - labels: Same as input_ids but with -100 for padding (for loss)

        Raises:
            ValueError: If the sample has none of the supported key sets.
        """
        sample = self.samples[idx]

        if "input_ids" in sample:
            input_ids = sample["input_ids"]
        elif "user" in sample and "assistant" in sample:
            input_ids = self.tokenizer.encode_conversation(
                user_message=sample["user"],
                assistant_message=sample["assistant"],
                max_length=self.max_length,
            )
        elif "text" in sample:
            input_ids = self.tokenizer.encode(
                sample["text"],
                add_special_tokens=True,
                max_length=self.max_length,
                truncation=True,
            )
        elif "question" in sample and "answer" in sample:
            input_ids = self.tokenizer.encode_conversation(
                user_message=sample["question"],
                assistant_message=sample["answer"],
                max_length=self.max_length,
            )
        else:
            raise ValueError(f"Unknown sample format: {list(sample.keys())}")

        was_truncated = len(input_ids) > self.max_length
        # Always work on a copy so the stored sample is never mutated.
        input_ids = list(input_ids[: self.max_length])
        eos_token_id = self.tokenizer.eos_token_id
        if was_truncated and input_ids and input_ids[-1] != eos_token_id:
            # A truncated sequence still has to end with EOS.
            input_ids[-1] = eos_token_id

        real_length = len(input_ids)
        padding_length = self.max_length - real_length

        attention_mask = [1] * real_length
        # Padding positions get -100 so the loss ignores them.
        labels = list(input_ids)
        if padding_length > 0:
            input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length
            labels = labels + [-100] * padding_length

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
|
|
|
|
|
|
|
|
class StreamingTextDataset(IterableDataset):
    """Streaming dataset for large text files.

    Memory-efficient dataset that reads files line by line, tokenizes each
    line, and emits fixed-length chunks of the concatenated token stream.
    Useful for training on large text corpora.
    """

    def __init__(
        self,
        data_files: List[str],
        tokenizer: SLMTokenizer,
        max_length: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
    ):
        """Initialize streaming dataset.

        Args:
            data_files: List of text file paths
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length (must be positive)
            shuffle: Whether to shuffle the order in which files are read
                (lines inside a file are always read in order)
            seed: Random seed for shuffling

        Raises:
            ValueError: If ``max_length`` is not positive.
            FileNotFoundError: If any of ``data_files`` does not exist.
        """
        if max_length < 1:
            # A non-positive chunk size would make the chunking loop spin.
            raise ValueError(f"max_length must be a positive integer, got {max_length}")

        self.data_files = data_files
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.shuffle = shuffle
        self.seed = seed

        for path in data_files:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Data file not found: {path}")

    def _encode_line(self, line: str) -> List[int]:
        """Tokenize one non-empty line.

        JSON objects with ``user``/``assistant`` or ``text`` keys are encoded
        from those fields; anything else (plain text, or JSON that is not such
        an object) is encoded as the raw line.
        """
        try:
            data = json.loads(line)
        except ValueError:
            # JSONDecodeError is a ValueError; the broader class also covers
            # e.g. over-long integer literals. Either way: plain text.
            data = None

        # Only dicts can carry the known fields. Bare numbers, strings, etc.
        # are valid JSON but must be treated as raw text.
        if isinstance(data, dict):
            if "user" in data and "assistant" in data:
                return self.tokenizer.encode_conversation(
                    data["user"], data["assistant"]
                )
            if "text" in data:
                return self.tokenizer.encode(data["text"], add_special_tokens=True)

        return self.tokenizer.encode(line, add_special_tokens=True)

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterate over all samples in all files.

        With DataLoader workers, each worker reads a disjoint, strided subset
        of the files. Tokens are accumulated across lines and files, and a
        sample is yielded for every ``max_length`` tokens; a final shorter
        remainder is yielded padded.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            files_to_process = list(self.data_files)
        else:
            # Strided split keeps the load balanced even when the file count
            # is not a multiple of (or is smaller than) the worker count.
            files_to_process = list(
                self.data_files[worker_info.id :: worker_info.num_workers]
            )

        if self.shuffle:
            random.Random(self.seed).shuffle(files_to_process)

        chunk_size = self.max_length
        buffer: List[int] = []

        for file_path in files_to_process:
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    buffer.extend(self._encode_line(line))

                    # Emit every complete chunk, then drop them in one go
                    # instead of re-slicing the buffer once per chunk.
                    full_chunks = len(buffer) // chunk_size
                    if full_chunks:
                        for i in range(full_chunks):
                            start = i * chunk_size
                            yield self._create_sample(
                                buffer[start : start + chunk_size]
                            )
                        del buffer[: full_chunks * chunk_size]

        if buffer:
            yield self._create_sample(buffer)

    def _create_sample(self, tokens: List[int]) -> Dict[str, torch.Tensor]:
        """Create a training sample from tokens.

        Truncates to ``max_length``, right-pads with the tokenizer's pad token,
        and sets labels to -100 on padding positions.
        """
        input_ids = tokens[: self.max_length]
        real_length = len(input_ids)
        padding_length = self.max_length - real_length

        attention_mask = [1] * real_length
        labels = list(input_ids)
        if padding_length > 0:
            input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length
            labels = labels + [-100] * padding_length

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
|
|
|
|
|
|
|
|
class PackedDataset(Dataset):
    """Dataset that packs multiple short sequences into one.

    Efficient for training when samples are shorter than max_length.
    Tokenized samples are concatenated back to back, in order, until the next
    one would overflow ``max_length``. No separator is inserted here; any
    boundary tokens come from the tokenizer's own output.
    """

    def __init__(
        self,
        samples: List[Dict],
        tokenizer: SLMTokenizer,
        max_length: int = 1024,
    ):
        """Initialize packed dataset.

        Args:
            samples: List of samples with "user" and "assistant" keys, or a
                "text" key; samples with neither are skipped
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.packed_samples = self._pack_sequences(samples)
        print(f"Packed {len(samples)} samples into {len(self.packed_samples)} sequences")

    def _pack_sequences(self, samples: List[Dict]) -> List[List[int]]:
        """Pack short sequences together.

        Greedy, order-preserving packing: a sample joins the current sequence
        if it fits, otherwise the current sequence is closed and the sample
        starts a new one. A sample longer than ``max_length`` is truncated.

        Args:
            samples: Raw samples to tokenize and pack.

        Returns:
            Token-id lists, each at most ``max_length`` long.
        """
        packed = []
        current_sequence = []

        for sample in samples:
            if "user" in sample and "assistant" in sample:
                tokens = self.tokenizer.encode_conversation(
                    sample["user"], sample["assistant"]
                )
            elif "text" in sample:
                tokens = self.tokenizer.encode(sample["text"], add_special_tokens=True)
            else:
                # Unknown format: silently skipped.
                continue

            if len(current_sequence) + len(tokens) <= self.max_length:
                current_sequence.extend(tokens)
                continue

            if current_sequence:
                packed.append(current_sequence)
            # The slice is a fresh list, so later extends never touch the
            # tokenizer's return value.
            current_sequence = tokens[: self.max_length]

        if current_sequence:
            packed.append(current_sequence)

        return packed

    def __len__(self) -> int:
        """Return the number of packed sequences."""
        return len(self.packed_samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a packed sample.

        Returns:
            Dictionary with ``input_ids`` (right-padded to ``max_length``),
            ``attention_mask`` (1 for real tokens, 0 for padding) and
            ``labels`` (token ids, -100 on padding).
        """
        tokens = self.packed_samples[idx]
        real_length = len(tokens)
        padding_length = self.max_length - real_length

        attention_mask = [1] * real_length
        labels = list(tokens)
        if padding_length > 0:
            # Concatenation builds new lists; the stored sequence is untouched.
            tokens = tokens + [self.tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length
            labels = labels + [-100] * padding_length

        return {
            "input_ids": torch.tensor(tokens, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
|
|
|
|
|
|
|
|
def create_train_val_split(
    samples: List[Dict],
    val_ratio: float = 0.01,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict]]:
    """Split samples into train and validation sets.

    The input list is left untouched; a shuffled copy is split. The same seed
    always gives the same split, and the module-level ``random`` state is not
    modified.

    Args:
        samples: List of all samples
        val_ratio: Ratio for validation set; the validation size is
            ``int(len(samples) * val_ratio)``, i.e. rounded down
        seed: Random seed

    Returns:
        Tuple of (train_samples, val_samples)
    """
    # A private generator yields the same permutation as seeding the global
    # one, without clobbering random state other code may depend on.
    rng = random.Random(seed)
    shuffled = list(samples)
    rng.shuffle(shuffled)

    val_size = int(len(shuffled) * val_ratio)
    val_samples = shuffled[:val_size]
    train_samples = shuffled[val_size:]

    return train_samples, val_samples
|
|
|
|
|
|
|
|
def load_jsonl(file_path: str) -> List[Dict]:
    """Read a UTF-8 JSONL file and return one decoded value per non-blank line."""
    with open(file_path, "r", encoding="utf-8") as handle:
        # Strip first so whitespace-only lines are dropped along with empty ones.
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]
|
|
|
|
|
|
|
|
def save_jsonl(samples: List[Dict], file_path: str):
    """Write each sample as one JSON document per line to a UTF-8 file.

    An existing file at ``file_path`` is overwritten.
    """
    with open(file_path, "w", encoding="utf-8") as handle:
        # The generator is consumed lazily, so records are serialized and
        # written one at a time, in order.
        handle.writelines(json.dumps(record) + "\n" for record in samples)
|
|
|